/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.dex.visitors.regions;

import jadx.api.plugins.input.data.annotations.EncodedType;
import jadx.api.plugins.input.data.annotations.EncodedValue;
import jadx.api.plugins.input.data.attributes.JadxAttrType;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.IAttributeNode;
import jadx.core.dex.attributes.nodes.CodeFeaturesAttr;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IfOp;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.regions.SwitchRegion;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.dex.regions.conditions.IfRegion;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.regions.DepthRegionTraversal;
import jadx.core.dex.visitors.regions.IRegionIterativeVisitor;
import jadx.core.dex.visitors.regions.IfRegionVisitor;
import jadx.core.dex.visitors.regions.ReturnVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.RegionUtils;
import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jetbrains.annotations.Nullable;

/**
 * Restores {@code switch} statements over {@link String} values from their lowered form.
 * <p>
 * The shape handled here is a switch whose argument is the result of
 * {@code String.hashCode()}, and whose case bodies test the same string with
 * {@code String.equals(CONST)}. When every such test only assigns an integer constant to one
 * shared variable, and the next sibling region is a switch on that variable, the two switches are
 * merged into a single switch keyed by the string constants.
 * <p>
 * Any unexpected shape aborts the restore and adds a warning comment to the method. Failures
 * detected before the region replacement leave the region tree untouched.
 */
@JadxVisitor(name="SwitchOverStringVisitor", desc="Restore switch over string", runAfter={IfRegionVisitor.class}, runBefore={ReturnVisitor.class})
public class SwitchOverStringVisitor
extends AbstractVisitor
implements IRegionIterativeVisitor {
    private static final String RESTORE_FAILED_MSG = "Failed to restore switch over string. Please report as a decompilation issue";
    private static final String STR_HASH_CODE_ID = "java.lang.String.hashCode()I";
    private static final String STR_EQUALS_ID = "java.lang.String.equals(Ljava/lang/Object;)Z";

    /**
     * Runs the iterative region traversal, but only for methods marked with the
     * {@code SWITCH} code feature.
     */
    @Override
    public void visit(MethodNode mth) throws JadxException {
        if (!CodeFeaturesAttr.contains(mth, CodeFeaturesAttr.CodeFeature.SWITCH)) {
            return;
        }
        DepthRegionTraversal.traverseIterative(mth, this);
    }

    /**
     * Tries to restore a switch over string for every visited {@link SwitchRegion}.
     *
     * @return {@code true} if a switch region was replaced, {@code false} otherwise
     */
    @Override
    public boolean visitRegion(MethodNode mth, IRegion region) {
        if (region instanceof SwitchRegion) {
            return this.restoreSwitchOverString(mth, (SwitchRegion) region);
        }
        return false;
    }

    /**
     * Checks whether {@code switchRegion} is a switch on {@code str.hashCode()}. If so, it
     * rebuilds the region as a switch on the string itself.
     *
     * @return {@code true} if the region was replaced in its parent
     */
    private boolean restoreSwitchOverString(MethodNode mth, SwitchRegion switchRegion) {
        try {
            InsnNode swInsn = BlockUtils.getLastInsnWithType(switchRegion.getHeader(), InsnType.SWITCH);
            if (swInsn == null) {
                return false;
            }
            RegisterArg strArg = getStrHashCodeArg(swInsn.getArg(0));
            if (strArg == null) {
                return false;
            }
            int casesCount = switchRegion.getCases().size();
            SSAVar strVar = strArg.getSVar();
            // One use of the string is the 'hashCode()' call; each case needs at least one 'equals()' use.
            if (strVar.getUseCount() - 1 < casesCount) {
                return false;
            }
            Map<InsnNode, String> strEqInsns = collectEqualsInsns(mth, strVar);
            if (strEqInsns.size() < casesCount) {
                return false;
            }
            // Quick checks passed: collect per-string case data.
            SwitchData switchData = new SwitchData(mth, switchRegion);
            switchData.setStrEqInsns(strEqInsns);
            switchData.setCases(new ArrayList<>(strEqInsns.size()));
            for (SwitchRegion.CaseInfo swCaseInfo : switchRegion.getCases()) {
                if (!processCase(switchData, swCaseInfo)) {
                    mth.addWarnComment(RESTORE_FAILED_MSG);
                    return false;
                }
            }
            // Match the index variable assigned in each case with the following code switch.
            if (!mergeWithCode(switchData)) {
                mth.addWarnComment(RESTORE_FAILED_MSG);
                return false;
            }
            IRegion parentRegion = switchRegion.getParent();
            SwitchRegion replaceRegion = new SwitchRegion(parentRegion, switchRegion.getHeader());
            for (SwitchRegion.CaseInfo caseInfo : switchData.getNewCases()) {
                replaceRegion.addCase(Collections.unmodifiableList(caseInfo.getKeys()), caseInfo.getContainer());
            }
            if (!parentRegion.replaceSubBlock(switchRegion, replaceRegion)) {
                mth.addWarnComment(RESTORE_FAILED_MSG);
                return false;
            }
            markCodeForRemoval(switchData);
            // The switch now dispatches on the string register instead of its hash code.
            swInsn.replaceArg(swInsn.getArg(0), strArg.duplicate());
            return true;
        } catch (Throwable e) {
            // Best effort: never let this pass break decompilation of the whole method.
            mth.addWarnComment(RESTORE_FAILED_MSG, e);
            return false;
        }
    }

    /**
     * Flags the code made redundant by the restored switch, then calls
     * {@link InsnRemover#removeAllMarked}.
     * <p>
     * Redundant code includes the collected {@code toRemove} nodes, the merged code switch
     * (detached from its parent, header flagged) and every assignment/use of the index
     * variable. Errors are reported as a warning comment only.
     */
    private static void markCodeForRemoval(SwitchData switchData) {
        MethodNode mth = switchData.getMth();
        try {
            switchData.getToRemove().forEach(i -> i.add(AFlag.REMOVE));
            SwitchRegion codeSwitch = switchData.getCodeSwitch();
            if (codeSwitch != null) {
                IRegion parentRegion = switchData.getSwitchRegion().getParent();
                parentRegion.getSubBlocks().remove(codeSwitch);
                codeSwitch.getHeader().add(AFlag.REMOVE);
            }
            RegisterArg numArg = switchData.getNumArg();
            if (numArg != null) {
                for (SSAVar ssaVar : numArg.getSVar().getCodeVar().getSsaVars()) {
                    InsnNode assignInsn = ssaVar.getAssignInsn();
                    if (assignInsn != null) {
                        assignInsn.add(AFlag.REMOVE);
                    }
                    for (RegisterArg useArg : ssaVar.getUseList()) {
                        InsnNode parentInsn = useArg.getParentInsn();
                        if (parentInsn != null) {
                            parentInsn.add(AFlag.REMOVE);
                        }
                    }
                    mth.removeSVar(ssaVar);
                }
            }
            InsnRemover.removeAllMarked(mth);
        } catch (Throwable e) {
            mth.addWarnComment("Failed to clean up code after switch over string restore", e);
        }
    }

    /**
     * Links each string case to the code switch that follows the hash switch.
     * <p>
     * Every string case must contain exactly one instruction (ignoring {@code break}) that
     * assigns an integer constant to the same variable. The next sibling region must be a
     * switch on that variable. On success, the merged cases, the code switch and the index
     * variable are stored in {@code switchData}.
     *
     * @return {@code true} if merging succeeded and {@code switchData.getNewCases()} is set
     */
    private static boolean mergeWithCode(SwitchData switchData) {
        MethodNode mth = switchData.getMth();
        List<CaseData> cases = switchData.getCases();
        RegisterArg numArg = null;
        int extracted = 0;
        for (CaseData caseData : cases) {
            List<InsnNode> insns = RegionUtils.collectInsns(mth, caseData.getCode());
            insns.removeIf(i -> i.getType() == InsnType.BREAK);
            if (insns.size() != 1) {
                continue;
            }
            InsnNode numInsn = insns.get(0);
            RegisterArg result = numInsn.getResult();
            // Only 'var = CONST' assignments can link a string case to the code switch;
            // result-less instructions (e.g. 'return CONST') must not be counted.
            if (result == null || numInsn.getArgsCount() != 1) {
                continue;
            }
            Object constVal = InsnUtils.getConstValueByArg(mth.root(), numInsn.getArg(0));
            if (!(constVal instanceof LiteralArg)) {
                continue;
            }
            if (numArg == null) {
                numArg = result;
            } else if (!numArg.sameCodeVar(result)) {
                // Cases assign different variables: not the expected shape.
                return false;
            }
            caseData.setCodeNum((int) ((LiteralArg) constVal).getLiteral());
            extracted++;
        }
        if (extracted == 0) {
            // No case assigns an index constant, so there is no code switch to merge with.
            // Fail explicitly: the caller needs 'newCases', which is only produced below.
            return false;
        }
        if (extracted != cases.size()) {
            return false;
        }
        cases.sort(Comparator.comparingInt(CaseData::getCodeNum));
        IContainer nextContainer = RegionUtils.getNextContainer(mth, switchData.getSwitchRegion());
        if (!(nextContainer instanceof SwitchRegion)) {
            return false;
        }
        SwitchRegion codeSwitch = (SwitchRegion) nextContainer;
        InsnNode swInsn = BlockUtils.getLastInsnWithType(codeSwitch.getHeader(), InsnType.SWITCH);
        if (swInsn == null || !swInsn.getArg(0).isSameCodeVar(numArg)) {
            return false;
        }
        Map<Integer, CaseData> casesMap = new HashMap<>(cases.size());
        for (CaseData caseData : cases) {
            if (casesMap.put(caseData.getCodeNum(), caseData) != null) {
                // Two string cases assign the same index.
                return false;
            }
            RegionUtils.visitBlocks(mth, caseData.getCode(), block -> switchData.getToRemove().add((IAttributeNode) block));
        }
        List<SwitchRegion.CaseInfo> newCases = buildMergedCases(codeSwitch, casesMap);
        if (newCases == null) {
            return false;
        }
        switchData.setCodeSwitch(codeSwitch);
        switchData.setNumArg(numArg);
        switchData.setNewCases(newCases);
        return true;
    }

    /**
     * Translates the cases of the code switch (keyed by index) into cases keyed by the strings
     * that assigned those indexes.
     * <p>
     * Consumes entries of {@code casesMap}. Strings whose index has no explicit case are
     * attached to the default case, if one exists.
     *
     * @return the new case list, or {@code null} if a case key cannot be resolved
     */
    @Nullable
    private static List<SwitchRegion.CaseInfo> buildMergedCases(SwitchRegion codeSwitch, Map<Integer, CaseData> casesMap) {
        List<SwitchRegion.CaseInfo> newCases = new ArrayList<>(codeSwitch.getCases().size());
        SwitchRegion.CaseInfo defaultCase = null;
        for (SwitchRegion.CaseInfo caseInfo : codeSwitch.getCases()) {
            if (caseInfo.getKeys().isEmpty()) {
                return null;
            }
            List<Object> strKeys = new ArrayList<>();
            boolean isDefault = false;
            for (Object key : caseInfo.getKeys()) {
                if (key == SwitchRegion.DEFAULT_CASE_KEY) {
                    isDefault = true;
                    continue;
                }
                Integer intKey = unwrapIntKey(key);
                if (intKey == null) {
                    return null;
                }
                CaseData caseData = casesMap.remove(intKey);
                if (caseData == null) {
                    return null;
                }
                strKeys.addAll(caseData.getStrValues());
            }
            SwitchRegion.CaseInfo newCase = new SwitchRegion.CaseInfo(strKeys, caseInfo.getContainer());
            if (isDefault) {
                defaultCase = newCase;
            }
            newCases.add(newCase);
        }
        if (defaultCase != null) {
            // Filled only after all explicit keys are resolved, so the default case never
            // consumes an index that a later explicit case still needs.
            for (CaseData caseData : casesMap.values()) {
                defaultCase.getKeys().addAll(caseData.getStrValues());
            }
            casesMap.clear();
            defaultCase.getKeys().add(SwitchRegion.DEFAULT_CASE_KEY);
        }
        return newCases;
    }

    /**
     * Returns the integer value of a switch case key.
     * <p>
     * Supported keys are plain {@link Integer}s and fields with an {@code ENCODED_INT}
     * constant value.
     *
     * @return the key value, or {@code null} for unsupported keys
     */
    @Nullable
    private static Integer unwrapIntKey(Object key) {
        if (key instanceof Integer) {
            return (Integer) key;
        }
        if (key instanceof FieldNode) {
            EncodedValue encodedValue = (EncodedValue) ((FieldNode) key).get(JadxAttrType.CONSTANT_VALUE);
            if (encodedValue != null && encodedValue.getType() == EncodedType.ENCODED_INT) {
                return (Integer) encodedValue.getValue();
            }
        }
        return null;
    }

    /**
     * Collects the {@code String.equals} invocations that use {@code strVar} and whose
     * argument 1 resolves to a constant string.
     *
     * @return map from the {@code equals} invoke instruction to the compared constant
     */
    private static Map<InsnNode, String> collectEqualsInsns(MethodNode mth, SSAVar strVar) {
        // IdentityHashMap: instruction keys are compared by reference.
        Map<InsnNode, String> map = new IdentityHashMap<>(strVar.getUseCount() - 1);
        for (RegisterArg useReg : strVar.getUseList()) {
            InsnNode parentInsn = useReg.getParentInsn();
            if (parentInsn == null || parentInsn.getType() != InsnType.INVOKE) {
                continue;
            }
            InvokeNode inv = (InvokeNode) parentInsn;
            if (!inv.getCallMth().getRawFullId().equals(STR_EQUALS_ID)) {
                continue;
            }
            Object strValue = InsnUtils.getConstValueByArg(mth.root(), inv.getArg(1));
            if (strValue instanceof String) {
                map.put(parentInsn, (String) strValue);
            }
        }
        return map;
    }

    /**
     * Converts every {@link IfRegion} inside one hash-switch case into a {@link CaseData}
     * and appends it to {@code switchData.getCases()}.
     *
     * @return {@code false} if any if-region does not match the expected string check
     */
    private static boolean processCase(SwitchData switchData, SwitchRegion.CaseInfo caseInfo) {
        // Mutable flag that the visitor lambda can write to.
        AtomicBoolean fail = new AtomicBoolean(false);
        RegionUtils.visitRegions(switchData.getMth(), caseInfo.getContainer(), region -> {
            if (fail.get()) {
                return false;
            }
            if (region instanceof IfRegion) {
                CaseData caseData = fillCaseData((IfRegion) region, switchData);
                if (caseData == null) {
                    fail.set(true);
                    return false;
                }
                switchData.getCases().add(caseData);
            }
            return true;
        });
        return !fail.get();
    }

    /**
     * Builds case data from an {@code if (str.equals(CONST))} region.
     * <p>
     * The branch taken when the strings are equal becomes the case code. The condition
     * instruction and the condition blocks are queued for removal.
     *
     * @return the case data, or {@code null} if the region is not a recognized string check
     */
    @Nullable
    private static CaseData fillCaseData(IfRegion ifRegion, SwitchData switchData) {
        IfCondition condition = Objects.requireNonNull(ifRegion.getCondition());
        boolean neg = false;
        if (condition.getMode() == IfCondition.Mode.NOT) {
            condition = condition.getArgs().get(0);
            neg = true;
        }
        if (!condition.isCompare()) {
            return null;
        }
        IfNode ifInsn = condition.getCompare().getInsn();
        InsnArg firstArg = ifInsn.getArg(0);
        if (!firstArg.isInsnWrap()) {
            return null;
        }
        String str = switchData.getStrEqInsns().get(((InsnWrapArg) firstArg).getWrapInsn());
        if (str == null) {
            return null;
        }
        // 'equals() != true' and 'equals() == false' invert the branch. Toggle the flag, so that
        // combined with an outer NOT the two negations cancel out.
        InsnArg secondArg = ifInsn.getArg(1);
        if ((ifInsn.getOp() == IfOp.NE && secondArg.isTrue())
                || (ifInsn.getOp() == IfOp.EQ && secondArg.isFalse())) {
            neg = !neg;
        }
        IContainer code = neg ? ifRegion.getElseRegion() : ifRegion.getThenRegion();
        if (code == null) {
            return null;
        }
        switchData.getToRemove().add(ifInsn);
        switchData.getToRemove().addAll(ifRegion.getConditionBlocks());
        CaseData caseData = new CaseData();
        caseData.getStrValues().add(str);
        caseData.setCode(code);
        return caseData;
    }

    /**
     * Resolves a switch argument to the string register whose {@code hashCode()} it holds.
     * The argument may be a register or a wrapped instruction.
     *
     * @return the string register, or {@code null} if the argument is not such a hash code
     */
    @Nullable
    private static RegisterArg getStrHashCodeArg(InsnArg arg) {
        if (arg.isRegister()) {
            return getStrFromInsn(((RegisterArg) arg).getAssignInsn());
        }
        if (arg.isInsnWrap()) {
            return getStrFromInsn(((InsnWrapArg) arg).getWrapInsn());
        }
        return null;
    }

    /**
     * Checks whether {@code insn} is a {@code String.hashCode()} invocation on a register.
     *
     * @return the instance register, or {@code null} if the instruction does not match
     */
    @Nullable
    private static RegisterArg getStrFromInsn(@Nullable InsnNode insn) {
        if (insn == null || insn.getType() != InsnType.INVOKE) {
            return null;
        }
        InvokeNode invInsn = (InvokeNode) insn;
        MethodInfo callMth = invInsn.getCallMth();
        if (!callMth.getRawFullId().equals(STR_HASH_CODE_ID)) {
            return null;
        }
        InsnArg arg = invInsn.getInstanceArg();
        if (arg == null || !arg.isRegister()) {
            return null;
        }
        return (RegisterArg) arg;
    }

    /**
     * One string case.
     * <p>
     * Holds the string constant(s), the code executed for them, and the index assigned for
     * the code switch. The index is -1 until it is extracted.
     */
    private static final class CaseData {
        private final List<String> strValues = new ArrayList<>();
        private IContainer code = null;
        private int codeNum = -1;

        private CaseData() {
        }

        public List<String> getStrValues() {
            return this.strValues;
        }

        public IContainer getCode() {
            return this.code;
        }

        public void setCode(IContainer code) {
            this.code = code;
        }

        public int getCodeNum() {
            return this.codeNum;
        }

        public void setCodeNum(int codeNum) {
            this.codeNum = codeNum;
        }

        @Override
        public String toString() {
            return "CaseData{" + this.strValues + "}";
        }
    }

    /**
     * Mutable state shared by the steps of one switch-over-string restore attempt.
     */
    private static final class SwitchData {
        private final MethodNode mth;
        private final SwitchRegion switchRegion;
        private final List<IAttributeNode> toRemove = new ArrayList<>();
        private Map<InsnNode, String> strEqInsns;
        private List<CaseData> cases;
        private List<SwitchRegion.CaseInfo> newCases;
        private SwitchRegion codeSwitch;
        private RegisterArg numArg;

        private SwitchData(MethodNode mth, SwitchRegion switchRegion) {
            this.mth = mth;
            this.switchRegion = switchRegion;
        }

        public List<CaseData> getCases() {
            return this.cases;
        }

        public void setCases(List<CaseData> cases) {
            this.cases = cases;
        }

        public List<SwitchRegion.CaseInfo> getNewCases() {
            return this.newCases;
        }

        public void setNewCases(List<SwitchRegion.CaseInfo> cases) {
            this.newCases = cases;
        }

        public MethodNode getMth() {
            return this.mth;
        }

        public Map<InsnNode, String> getStrEqInsns() {
            return this.strEqInsns;
        }

        public void setStrEqInsns(Map<InsnNode, String> strEqInsns) {
            this.strEqInsns = strEqInsns;
        }

        public SwitchRegion getSwitchRegion() {
            return this.switchRegion;
        }

        public List<IAttributeNode> getToRemove() {
            return this.toRemove;
        }

        public SwitchRegion getCodeSwitch() {
            return this.codeSwitch;
        }

        public void setCodeSwitch(SwitchRegion codeSwitch) {
            this.codeSwitch = codeSwitch;
        }

        public RegisterArg getNumArg() {
            return this.numArg;
        }

        public void setNumArg(RegisterArg numArg) {
            this.numArg = numArg;
        }
    }
}

