
import pymel.core as core
import pymel.api as api
import maya.mel as mel
import maya.cmds as cmds

from .general import *
from .network import *
from .mirror import *

def kinematicSwitchByType(network, typ):
    """
    Route a kinematic (IK/FK) switch to the handler for a component type.

    Legacy component types are handled through ``kinematicSwitch``; the
    "*_hier" and "biped_rig_*" types instead send an ikfk switch command to
    the network's evaluator through ``core.evaluatorCommand``. Types that
    match none of the branches are ignored.

    Args:
        network: Network wrapper of the component (its ``get()`` and
            ``getAttr()`` methods are used).
        typ (str): Component type, e.g. "arm", "leg_2", "finger_hier".
    """
    armTypes = ("arm", "wing", "dragonWing", "arm_1", "arm_2", "arm_3", "arm_4")
    legTypes = ("leg", "frontLeg", "midLeg", "backLeg", "leg_1", "leg_2", "leg_3", "leg_4")

    if typ in armTypes:
        kinematicSwitch(network.get(), "limb")
        return

    if typ == "head":
        kinematicSwitch(network.get(), "head")
        return

    if typ == "finger":
        kinematicSwitch(network.get(), "finger")
        return

    if typ in legTypes:
        optionsNode = network.getAttr("options").node()
        ns = optionsNode.namespace()
        sidePrefix = optionsNode.stripNamespace().split("_")[0]

        # Match limb and toe without touching the kinematic attribute yet...
        kinematicSwitch(network.get(), "limb", 0)
        kinematicSwitch(network.get(), "toe", 0)
        if optionsNode.ikfk.get():
            resetControl(ns + sidePrefix + "_" + typ + "_footroll_control")
        # ...then match the toe once more, this time writing kinematic.
        kinematicSwitch(network.get(), "toe", 1)
        return

    if typ in ("arm_hier", "leg_hier", "head_hier"):
        core.evaluatorCommand(ev=network.getAttr("evaluator"), c="ikfk_switch")
        return

    if typ == "finger_hier":
        blockId = int(network.getAttr("blockId"))
        core.evaluatorCommand(ev=network.getAttr("evaluator"), c="ikfk_switch", ai=("block", blockId))
        return

    if typ == "biped_rig_arm" or typ == "biped_rig_leg":
        sideLetter = network.get().stripNamespace()[0]  # first character of the node name
        suffix = "_arm_ikfk_switch" if typ == "biped_rig_arm" else "_leg_ikfk_switch"
        core.evaluatorCommand(ev=network.getAttr("evaluator"), c=sideLetter + suffix)
        return

    if typ == "biped_rig_head":
        core.evaluatorCommand(ev=network.getAttr("evaluator"), c="head_ikfk_switch")
        return

    if typ == "biped_rig_fingers":
        blockId = int(network.getAttr("blockId"))
        sideLetter = network.get().stripNamespace()[0]  # first character of the node name
        core.evaluatorCommand(ev=network.getAttr("evaluator"), c=sideLetter + "_fingers_ikfk_switch", ai=("block", blockId))

def _getFrameSpecList(ac):
    """
    Collect the (time, value) pair of every key on an animation curve.

    Args:
        ac: Animation curve node.

    Returns:
        list: ``(time, value)`` tuples in key-index order.
    """
    specList = []
    for i in range(core.keyframe(ac, q=True, kc=True)):
        t = core.keyframe(ac, index=i, q=True)
        v = core.keyframe(ac, index=i, q=True, vc=True)
        if t and v:
            specList.append((t[0], v[0]))
    return specList


def _getValueFromFrameSpecList(specList, tm, default=0):
    """
    Return the value of the last key at or before ``tm`` (stepped lookup).

    Times before the first key resolve to the first key's value.

    Args:
        specList (list): ``(time, value)`` tuples, assumed sorted by time.
        tm (float): Time to evaluate.
        default: Value returned when ``specList`` is empty (for example, an
            anim curve that exists but has no keys). Defaults to 0.

    Returns:
        The stepped value at ``tm``.
    """
    if not specList:
        return default

    value = specList[0][1]
    for t, v in specList:
        if t > tm:
            break
        value = v
    return value


def kinematicSwitchMulti(object):
    """
    Perform a kinematic switch on the selected controls and the given control.

    Each network is switched at most once per call. When a time range is
    selected on the playback slider, the switch is repeated on every keyed
    frame inside that range. Keyed frames come from the translate/rotate
    curves of all involved nodes plus the network's ``kinematic`` curve.
    Before each switch, ``kinematic`` is set to its stepped keyed value at
    that frame.

    Args:
        object (str or PyNode): Control for which the switch is requested
            (typically the control the menu was opened on).
    """
    nodes = set(core.ls(sl=True) + [core.PyNode(object)])

    mayaPlayBackSlider = mel.eval('$tmpVar=$gPlayBackSlider')
    selectedRange = cmds.timeControl(mayaPlayBackSlider, q=True, ra=True)
    rangeStart, rangeEnd = selectedRange[0], selectedRange[1]
    # A span of exactly one frame means "no range selected".
    isRangeSelected = rangeEnd - rangeStart > 1

    animCurves = []
    if isRangeSelected:
        for obj in nodes:
            for attrName in ("tx", "ty", "tz", "rx", "ry", "rz"):
                animCurves += obj.attr(attrName).listConnections(s=True, d=False, type="animCurve")

    skip = []  # networks already switched; do not switch more than once per network
    for obj in nodes:
        network = getNetwork(obj)
        if not network or network.get() in skip:
            continue

        typ = network.getAttr("type")

        if not isRangeSelected:
            kinematicSwitchByType(network, typ)
            skip.append(network.get())
            continue

        kinematic = network.getAttr("kinematic")
        if not kinematic:
            continue

        kinematic_animCurves = kinematic.listConnections(s=True, d=False, type="animCurve")
        frameSpecList = _getFrameSpecList(kinematic_animCurves[0]) if kinematic_animCurves else [(0, 0)]

        processedFrames = set()
        switched = False
        for ac in animCurves + kinematic_animCurves:
            for i in range(core.keyframe(ac, q=True, kc=True)):
                time = core.keyframe(ac, index=i, q=True)
                if not time:
                    continue
                frame = time[0]
                if frame in processedFrames or not (rangeStart <= frame <= rangeEnd):
                    continue

                # Go to the keyed frame, restore the keyed kinematic state, then switch.
                core.currentTime(frame)
                kinematic.set(_getValueFromFrameSpecList(frameSpecList, frame))
                kinematicSwitchByType(network, typ)

                processedFrames.add(frame)
                switched = True

        if switched:
            skip.append(network.get())

def kinematicSwitch(network, typ="limb", doSetAttr=True):
    """
    Pose one side of an IK/FK component from its seamless targets and flip it.

    If the network's ``kinematic`` attribute equals 1, the IK-side controls
    (or the finger rig) are posed from their ``*_seamless`` counterparts and
    ``kinematic`` becomes 0. Otherwise the FK-side controls are posed and
    ``kinematic`` becomes 1.

    Args:
        network: Network node or name; it is wrapped with ``Network``.
        typ (str): One of "limb", "toe", "head", "finger"; anything else
            reports an error through ``core.error``.
        doSetAttr (bool): Whether to write the new ``kinematic`` value at the
            end. Defaults to True.
    """
    def copySettableRotation(source, target):
        # Copy rx, ry and rz one channel at a time, only where isSettable() is True.
        for axis in ("rx", "ry", "rz"):
            plug = getattr(target, axis)
            if plug.isSettable():
                plug.set(getattr(source, axis).get())

    badTypeMessage = "kinematicSwitch: typ must be 'limb', 'toe', 'head', 'finger'"

    network = Network(network)
    kinematic = network.getAttr("kinematic")

    if kinematic.get() == 1:  # pose the IK side
        newValue = 0

        if typ == "limb" or typ == "head":
            ik = network.getAttr("ik").node()
            if typ == "limb":
                elbow = network.getAttr("polevector").node()
            ikSeamless = network.getAttr("ik_seamless").node()
            if typ == "limb":
                elbowSeamless = network.getAttr("polevector_seamless").node()

            ik.t.set(ikSeamless.t.get())
            ik.r.set(ikSeamless.r.get())
            if typ == "limb":
                elbow.t.set(elbowSeamless.t.get())

        elif typ == "toe":
            toeIk = network.getAttr("toe_ik").node()
            toeIkSeamless = network.getAttr("toe_ik_seamless").node()
            copySettableRotation(toeIkSeamless, toeIk)

        elif typ == "finger":
            fingerKinematicSwitch(network)

        else:
            core.error(badTypeMessage)

    else:  # pose the FK side
        newValue = 1

        if typ == "limb":
            fkControls = [network.getAttr("fk%d" % n).node() for n in (1, 2, 3)]
            fkSeamless = [network.getAttr("fk%d_seamless" % n).node() for n in (1, 2, 3)]
            for seamless, control in zip(fkSeamless, fkControls):
                safeCopyAttrs(seamless, control, ["rx", "ry", "rz"])

            # Carry the IK joints' X scale (stretch) over to the options node.
            options = network.getAttr("options")
            joint1 = network.getAttr("ik_1_joint")
            joint2 = network.getAttr("ik_2_joint")
            if joint1 and joint2:
                scale1 = joint1.node().sx.get()
                scale2 = joint2.node().sx.get()
                options.node().scaleCoeff1X.set(scale1)
                options.node().scaleCoeff2X.set(scale2)

        elif typ == "toe":
            toeFk = network.getAttr("toe_fk").node()
            toeFkSeamless = network.getAttr("toe_fk_seamless").node()
            copySettableRotation(toeFkSeamless, toeFk)

        elif typ == "head":
            # Walk fk1, fk2, ... until the network has no further entry.
            index = 1
            while network.getAttr("fk" + str(index)):
                control = network.getAttr("fk" + str(index)).node()
                seamless = network.getAttr("fk" + str(index) + "_seamless").node()
                copySettableRotation(seamless, control)
                index += 1

            headFk = network.getAttr("fk").node()
            headFkSeamless = network.getAttr("fk_seamless").node()
            headFk.r.set(headFkSeamless.r.get())

        elif typ == "finger":
            fingerKinematicSwitch(network)

        else:
            core.error(badTypeMessage)

    if doSetAttr:
        kinematic.set(newValue)

def fingerKinematicSwitch(network):
    """
    IK/FK seamless kinematic switch for IK/FK fingers.

    Does nothing unless the network's ``type`` is "finger".

    If ``kinematic`` < 0.5, the FK control's ``rotate``, ``rotateA`` and
    ``rotateB`` are set from the rotations of the ``ik1``, ``ik2`` and
    ``ik3`` nodes, and ``kinematic`` is set to 1.

    Otherwise, the IK control is moved to ``fk4``'s world position. Its
    ``twist`` and ``tip`` attributes are then solved so the IK chain follows
    the FK pose, and ``kinematic`` is set to 0.

    Args:
        network (Network): Finger's network.
    """
    import math

    def axisAngleDegrees(fromVec, toVec):
        # Rotation that takes fromVec onto toVec, returned as (axis, angle in degrees).
        axis = api.MVector()
        angle = api.doublePtr()
        api.MQuaternion(fromVec, toVec).getAxisAngle(axis, angle)
        return axis, math.degrees(angle.value())

    if network.getAttr("type") != "finger":
        return

    kinematic = network.getAttr("kinematic")
    ik_control = network.getAttr("ik_control").node()
    fk_control = network.getAttr("fk_control").node()

    if kinematic.get() < 0.5:  # move FK to IK
        ik1 = network.getAttr("ik1").node()
        ik2 = network.getAttr("ik2").node()
        ik3 = network.getAttr("ik3").node()

        fk_control.rotate.set(ik1.r.get())
        fk_control.rotateA.set(ik2.rz.get())
        fk_control.rotateB.set(ik3.rz.get())

        kinematic.set(1)

    else:  # move IK to FK
        fk3 = network.getAttr("fk3").node()
        fk4 = network.getAttr("fk4").node()
        ik3 = network.getAttr("ik3").node()
        ik4 = network.getAttr("ik4").node()
        sign = network.getAttr("normalSign")

        rotator = network.getAttr("rotator").node()

        # Start from a neutral IK pose before measuring.
        ik_control.tip.set(0)
        ik_control.twist.set(0)

        normal = fk3.getTranslation("world") - fk_control.getTranslation("world")
        normal.normalize()
        normal *= sign  # NOTE(review): assumes normalSign is +1/-1 so normal stays unit length

        ik_control.setTranslation(fk4.getTranslation("world"), "world")

        # Second row of each world matrix (the node's Y axis in world space).
        rotatorMatrix = rotator.worldMatrix.get()
        polevec = core.dt.Vector(rotatorMatrix.a10, rotatorMatrix.a11, rotatorMatrix.a12)

        fkMatrix = fk_control.worldMatrix.get()
        newPolevec = core.dt.Vector(fkMatrix.a10, fkMatrix.a11, fkMatrix.a12)

        # Remove the component along `normal` (Vector * Vector is a dot product),
        # so twist is measured in the plane perpendicular to the finger.
        polevec = polevec - (polevec * normal) * normal
        newPolevec = newPolevec - (newPolevec * normal) * normal

        axis, twist = axisAngleDegrees(polevec, newPolevec)
        if axis * normal < 0:  # the quaternion angle is unsigned; recover the direction
            twist *= -1

        vecIK = ik3.getTranslation("world") - ik4.getTranslation("world")
        vecFK = fk3.getTranslation("world") - fk4.getTranslation("world")

        _, tip = axisAngleDegrees(vecIK, vecFK)
        tip *= sign

        ik_control.tip.set(tip)
        ik_control.twist.set(twist)

        kinematic.set(0)
