/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.dex.visitors.regions.maker;

import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.attributes.nodes.RegionRefAttr;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.SwitchInsn;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.Region;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.visitors.regions.maker.RegionMaker;
import jadx.core.dex.visitors.regions.maker.RegionStack;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.Nullable;

/**
 * Builds a {@link SwitchRegion} for a block ending in a {@link SwitchInsn}: groups case keys by target
 * block, finds the block where control continues after the switch ("out" block), detects and orders
 * fall-through cases and finally builds a region for every case.
 */
final class SwitchRegionMaker {
    private final MethodNode mth;
    private final RegionMaker regionMaker;

    SwitchRegionMaker(MethodNode mth, RegionMaker regionMaker) {
        this.mth = mth;
        this.regionMaker = regionMaker;
    }

    /**
     * Creates a switch region for {@code insn} (the last instruction of {@code block}), adds it to
     * {@code currentRegion} and fills in all of its cases.
     *
     * @param currentRegion region the new switch region is appended to
     * @param block         block holding the switch instruction (the switch header)
     * @param insn          the switch instruction
     * @param stack         region stack; a new level is pushed while the cases are processed
     * @return the block where processing continues after the switch, possibly {@code null}
     */
    BlockNode process(IRegion currentRegion, BlockNode block, SwitchInsn insn, RegionStack stack) {
        int len = insn.getTargets().length;
        // Several keys may share one target block; insertion order keeps the original case order.
        Map<BlockNode, List<Object>> blocksMap = new LinkedHashMap<>(len);
        BlockNode[] targetBlocksArr = insn.getTargetBlocks();
        for (int i = 0; i < len; i++) {
            blocksMap.computeIfAbsent(targetBlocksArr[i], k -> new ArrayList<>(2)).add(insn.getKey(i));
        }
        BlockNode defCase = insn.getDefTargetBlock();
        if (defCase != null) {
            blocksMap.computeIfAbsent(defCase, k -> new ArrayList<>(1)).add(SwitchRegion.DEFAULT_CASE_KEY);
        }
        SwitchRegion sw = new SwitchRegion(currentRegion, block);
        insn.addAttr(new RegionRefAttr(sw));
        currentRegion.getSubBlocks().add(sw);

        stack.push(sw);
        BlockNode out = calcSwitchOut(block, stack);
        stack.addExit(out);
        processFallThroughCases(sw, out, stack, blocksMap);
        removeEmptyCases(insn, sw, defCase);
        stack.pop();
        return out;
    }

    /**
     * Detects cases that fall through into another case, reorders the cases if needed so that every
     * fall-through target directly follows its source, and builds the region of every case.
     * If no valid order exists, fall-through is dropped (with a warning) and code gets duplicated.
     */
    private void processFallThroughCases(SwitchRegion sw, @Nullable BlockNode out, RegionStack stack,
            Map<BlockNode, List<Object>> blocksMap) {
        Map<BlockNode, BlockNode> fallThroughCases = new LinkedHashMap<>();
        if (out != null) {
            BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet());
            caseBlocks.clear(out.getId());
            for (BlockNode successor : sw.getHeader().getCleanSuccessors()) {
                BitSet df = successor.getDomFrontier();
                // a case whose dominance frontier reaches another case block falls through into it
                if (df.intersects(caseBlocks)) {
                    fallThroughCases.put(successor, getOneIntersectionBlock(out, caseBlocks, df));
                }
            }
            if (!fallThroughCases.isEmpty() && isBadCasesOrder(blocksMap, fallThroughCases)) {
                Map<BlockNode, List<Object>> newBlocksMap = reOrderSwitchCases(blocksMap, fallThroughCases);
                if (isBadCasesOrder(newBlocksMap, fallThroughCases)) {
                    mth.addWarnComment("Can't fix incorrect switch cases order, some code will duplicate");
                    fallThroughCases.clear();
                } else {
                    blocksMap = newBlocksMap;
                }
            }
        }
        for (Map.Entry<BlockNode, List<Object>> entry : blocksMap.entrySet()) {
            List<Object> keysList = entry.getValue();
            BlockNode caseBlock = entry.getKey();
            if (stack.containsExit(caseBlock)) {
                // case jumps straight to an exit: emit an empty case body
                sw.addCase(keysList, new Region(stack.peekRegion()));
                continue;
            }
            BlockNode next = fallThroughCases.get(caseBlock);
            // temporarily stop region building at the fall-through target
            stack.addExit(next);
            Region caseRegion = regionMaker.makeRegion(caseBlock);
            stack.removeExit(next);
            if (next != null) {
                next.add(AFlag.FALL_THROUGH);
                caseRegion.add(AFlag.FALL_THROUGH);
            }
            sw.addCase(keysList, caseRegion);
        }
    }

    /**
     * Returns the single case block (other than {@code out}) contained in {@code fallThroughSet},
     * or {@code null} as decided by {@link BlockUtils#bitSetToOneBlock}.
     */
    @Nullable
    private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) {
        BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet);
        caseExits.clear(out.getId());
        caseExits.and(caseBlocks);
        return BlockUtils.bitSetToOneBlock(mth, caseExits);
    }

    /**
     * Calculates the block where control continues after the switch in {@code block}.
     * Candidates are the dominance frontiers of the switch successors (loop-end successors skipped),
     * excluding the switch block and the method exit. Loop start/end post-dominators are discarded
     * for switches inside loops, where {@code continue} may be inserted to resolve the out block.
     * May register the immediate post-dominator as an additional exit on {@code stack}.
     *
     * @throws JadxRuntimeException if the chosen out block was already processed
     */
    @Nullable
    private BlockNode calcSwitchOut(BlockNode block, RegionStack stack) {
        BitSet outs = BlockUtils.newBlocksBitSet(mth);
        for (BlockNode s : block.getCleanSuccessors()) {
            if (!s.contains(AFlag.LOOP_END)) {
                outs.or(s.getDomFrontier());
            }
        }
        outs.clear(block.getId());
        outs.clear(mth.getExitBlock().getId());

        BlockNode out = null;
        if (outs.cardinality() == 1) {
            out = BlockUtils.bitSetToOneBlock(mth, outs);
        } else {
            LoopInfo loop = mth.getLoopForBlock(block);
            if (loop != null) {
                outs.andNot(loop.getStart().getPostDoms());
                outs.andNot(loop.getEnd().getPostDoms());
                BlockNode loopEnd = loop.getEnd();
                if (outs.cardinality() == 2 && outs.get(loopEnd.getId())) {
                    // one candidate is the loop end: try to turn those jumps into 'continue'
                    List<BlockNode> outList = BlockUtils.bitSetToBlocks(mth, outs);
                    outList.remove(loopEnd);
                    BlockNode possibleOut = Utils.getOne(outList);
                    if (possibleOut != null && insertContinueInSwitch(block, possibleOut, loopEnd)) {
                        outs.clear(loopEnd.getId());
                        out = possibleOut;
                    }
                }
                if (outs.isEmpty()) {
                    return mth.getExitBlock();
                }
            }
            if (out == null) {
                BlockNode imPostDom = block.getIPostDom();
                if (outs.get(imPostDom.getId())) {
                    out = imPostDom;
                } else {
                    outs.andNot(block.getPostDoms());
                    out = BlockUtils.bitSetToOneBlock(mth, outs);
                }
            }
        }
        if (out != null && mth.isPreExitBlock(out)) {
            out = mth.getExitBlock();
        }
        BlockNode imPostDom = block.getIPostDom();
        if (out == null && imPostDom == mth.getExitBlock()) {
            return allSameReturns(stack);
        }
        if (out != imPostDom && !mth.isPreExitBlock(imPostDom)) {
            stack.addExit(imPostDom);
        }
        if (block.getCleanSuccessors().contains(imPostDom)) {
            stack.addExit(imPostDom);
        }
        if (out == null) {
            mth.addWarnComment("Failed to find 'out' block for switch in " + block + ". Please report as an issue.");
            out = block.getIPostDom();
        }
        if (out != null && regionMaker.isProcessed(out)) {
            throw new JadxRuntimeException("Failed to find switch 'out' block (already processed)");
        }
        return out;
    }

    /**
     * Picks an out block when the switch's immediate post-dominator is the method exit block.
     * If every predecessor of the exit block is a RETURN of the same value (any RETURN for void
     * methods), all of them become stack exits, all but the first are flagged REMOVE and
     * ADDED_TO_REGION, and the first one is returned. Otherwise the exit block is returned.
     */
    private BlockNode allSameReturns(RegionStack stack) {
        BlockNode exitBlock = mth.getExitBlock();
        List<BlockNode> preds = exitBlock.getPredecessors();
        int count = preds.size();
        if (count == 1) {
            return preds.get(0);
        }
        if (mth.getReturnType() == ArgType.VOID) {
            for (BlockNode pred : preds) {
                InsnNode insn = BlockUtils.getLastInsn(pred);
                if (insn == null || insn.getType() != InsnType.RETURN) {
                    return exitBlock;
                }
            }
        } else {
            List<InsnArg> returnArgs = new ArrayList<>(count);
            for (BlockNode pred : preds) {
                InsnNode insn = BlockUtils.getLastInsn(pred);
                if (insn == null || insn.getType() != InsnType.RETURN) {
                    return exitBlock;
                }
                returnArgs.add(insn.getArg(0));
            }
            InsnArg firstArg = returnArgs.get(0);
            if (firstArg.isRegister()) {
                RegisterArg reg = (RegisterArg) firstArg;
                for (int i = 1; i < count; i++) {
                    // compare every return value with the first one (not just the second)
                    InsnArg arg = returnArgs.get(i);
                    if (!arg.isRegister() || !((RegisterArg) arg).sameCodeVar(reg)) {
                        return exitBlock;
                    }
                }
            } else {
                for (int i = 1; i < count; i++) {
                    if (!returnArgs.get(i).equals(firstArg)) {
                        return exitBlock;
                    }
                }
            }
        }
        stack.addExits(preds);
        // keep only the first return; the identical ones are dropped from the output
        for (int i = 1; i < count; i++) {
            BlockNode blockNode = preds.get(i);
            blockNode.add(AFlag.REMOVE);
            blockNode.add(AFlag.ADDED_TO_REGION);
        }
        return preds.get(0);
    }

    /**
     * Removes empty cases when the default case is absent or empty: an empty default case is
     * always removed, other empty cases only for packed switches.
     */
    private void removeEmptyCases(SwitchInsn insn, SwitchRegion sw, BlockNode defCase) {
        boolean defaultCaseIsEmpty = defCase == null
                || sw.getCases().stream().anyMatch(c -> c.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY)
                        && RegionUtils.isEmpty(c.getContainer()));
        if (defaultCaseIsEmpty) {
            sw.getCases().removeIf(caseInfo -> RegionUtils.isEmpty(caseInfo.getContainer())
                    && (caseInfo.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY) || insn.isPacked()));
        }
    }

    /**
     * Checks whether some fall-through case is not immediately followed by its target case,
     * including the last case falling through to anything.
     */
    private boolean isBadCasesOrder(Map<BlockNode, List<Object>> blocksMap, Map<BlockNode, BlockNode> fallThroughCases) {
        BlockNode nextCaseBlock = null;
        for (BlockNode caseBlock : blocksMap.keySet()) {
            if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) {
                return true;
            }
            nextCaseBlock = fallThroughCases.get(caseBlock);
        }
        return nextCaseBlock != null;
    }

    /**
     * Reorders cases so that every fall-through target directly follows its source where possible.
     * Cases are grouped into fall-through chains that start at a case no other case falls into;
     * chains keep the original order of their heads. Cases not reached this way (e.g. a
     * fall-through cycle) are appended in their original order. The result may still be invalid
     * and must be checked with {@link #isBadCasesOrder}.
     */
    private Map<BlockNode, List<Object>> reOrderSwitchCases(Map<BlockNode, List<Object>> blocksMap,
            Map<BlockNode, BlockNode> fallThroughCases) {
        // A fall-through target must follow its source, so it can never start a chain.
        Set<BlockNode> fallThroughTargets = new HashSet<>(fallThroughCases.values());
        Set<BlockNode> placed = new HashSet<>();
        List<BlockNode> order = new ArrayList<>(blocksMap.size());
        for (BlockNode head : blocksMap.keySet()) {
            if (fallThroughTargets.contains(head)) {
                continue;
            }
            BlockNode current = head;
            while (current != null && blocksMap.containsKey(current) && placed.add(current)) {
                order.add(current);
                current = fallThroughCases.get(current);
            }
        }
        for (BlockNode caseBlock : blocksMap.keySet()) {
            if (placed.add(caseBlock)) {
                order.add(caseBlock);
            }
        }
        Map<BlockNode, List<Object>> newBlocksMap = new LinkedHashMap<>(blocksMap.size());
        for (BlockNode key : order) {
            newBlocksMap.put(key, blocksMap.get(key));
        }
        return newBlocksMap;
    }

    /**
     * For every switch case that reaches {@code loopEnd} but not {@code switchOut}, appends a
     * CONTINUE instruction to the first synthetic loop-end predecessor dominated by that case.
     *
     * @return {@code true} if at least one CONTINUE instruction was inserted
     */
    private boolean insertContinueInSwitch(BlockNode switchBlock, BlockNode switchOut, BlockNode loopEnd) {
        boolean inserted = false;
        for (BlockNode caseBlock : switchBlock.getCleanSuccessors()) {
            if (!caseBlock.getDomFrontier().get(loopEnd.getId()) || caseBlock == switchOut) {
                continue;
            }
            Set<BlockNode> caseBlocks = new HashSet<>(BlockUtils.collectBlocksDominatedBy(mth, caseBlock, caseBlock));
            if (caseBlocks.contains(switchOut)
                    || switchOut.getPredecessors().stream().anyMatch(caseBlocks::contains)) {
                continue;
            }
            // only the first loop-end predecessor inside this case is considered
            for (BlockNode p : loopEnd.getPredecessors()) {
                if (caseBlocks.contains(p)) {
                    if (p.isSynthetic()) {
                        p.getInstructions().add(new InsnNode(InsnType.CONTINUE, 0));
                        inserted = true;
                    }
                    break;
                }
            }
        }
        return inserted;
    }
}

