
import pymel.core as core
from .network import *

def copyAttrs(source, dest, exchange=False, reverseDict=None):
    """
    Copy keyable attribute values from one node to another.

    Every keyable attribute of ``source`` is looked up on ``dest`` by its
    short name. When ``dest`` has a settable attribute of that name, it
    receives the source value multiplied by the factor registered for that
    name in ``reverseDict`` (1.0 when the name is not listed).

    Nothing is done when either node does not exist.

    Args:
        source (core.PyNode or str): Source node.
        dest (core.PyNode or str): Destination node.
        exchange (bool, optional): Whether to swap the values between the two
            nodes instead of copying one way. Defaults to False.
        reverseDict (dict, optional): Maps attribute short names to the
            multiplier applied to transferred values, e.g. ``{"rz": -1}``.
            Defaults to None, meaning no multipliers.
    """
    # Both nodes are required below, so bail out as soon as either is missing
    # (the previous "and" only returned when both were missing).
    if not core.objExists(source) or not core.objExists(dest):
        return

    # Avoid a shared mutable default argument.
    factors = reverseDict if reverseDict is not None else {}

    source = core.PyNode(source)
    dest = core.PyNode(dest)

    for srcAttr in source.listAttr(k=True):  # keyable attributes only
        attrName = srcAttr.shortName()
        if not dest.hasAttr(attrName):
            continue

        destAttr = dest.attr(attrName)
        if not destAttr.isSettable():
            continue

        factor = factors.get(attrName, 1.0)
        if exchange:
            previous = destAttr.get()
            destAttr.set(srcAttr.get() * factor)
            # Apply the same factor on the way back so the swap is symmetric.
            # NOTE(review): this assumes the multipliers are +1/-1 (as every
            # caller in this file passes) - confirm before using other values.
            srcAttr.set(previous * factor)
        else:
            destAttr.set(srcAttr.get() * factor)

def safeCopyAttrs(source, dest, attrs, exchange=False):
    """
    Copy an explicit list of attributes from one node to another.

    An attribute is transferred only when it exists on both nodes and the
    destination attribute is settable; every other name is skipped.

    Args:
        source (core.PyNode): Source node.
        dest (core.PyNode): Destination node.
        attrs (list of str): Names of the attributes to transfer.
        exchange (bool, optional): Whether to swap the values between the two
            nodes instead of copying one way. Defaults to False.
    """
    source = core.PyNode(source)
    dest = core.PyNode(dest)

    for attrName in attrs:
        if not (source.hasAttr(attrName) and dest.hasAttr(attrName)):
            continue

        destAttr = dest.attr(attrName)
        if not destAttr.isSettable():
            continue

        srcAttr = source.attr(attrName)
        if not exchange:
            destAttr.set(srcAttr.get())
            continue

        # Swap: remember the destination value before it is overwritten.
        previous = destAttr.get()
        destAttr.set(srcAttr.get())
        srcAttr.set(previous)

def mirrorControl(ctrl, mirrorCtrl, centerObject, mirrorAxis=(1, 1, 1), mirrorScale=(1, 1, 1), exchange=False):
    """
    Mirror the transform of one control onto another around a center object.

    The world matrix of ``ctrl`` is expressed relative to ``centerObject``,
    re-applied through the center object's world matrix whose scale has been
    replaced by ``mirrorAxis``, converted into the parent space of
    ``mirrorCtrl`` and assigned to ``mirrorCtrl`` with its scale replaced by
    ``mirrorScale``.

    A node named ``<namespace of mirrorCtrl>main_control`` must exist; the
    scale part of its local matrix is folded into the computation.

    Args:
        ctrl (core.nt.Transform): Control to mirror from.
        mirrorCtrl (core.nt.Transform): Control to mirror to.
        centerObject (core.nt.Transform): Center object for mirroring.
        mirrorAxis (list of 3 floats): Scale set on the center object's world
            matrix, e.g. (-1, 1, 1).
        mirrorScale (list of 3 floats): Scale set on the resulting matrix
            before it is assigned to ``mirrorCtrl``.
        exchange (bool, optional): When True and the two controls differ, the
            original transform of ``mirrorCtrl`` is also mirrored back onto
            ``ctrl``. Defaults to False.
    """
    ctrl = core.PyNode(ctrl)
    mirrorCtrl = core.PyNode(mirrorCtrl)
    centerObject = core.PyNode(centerObject)

    swap = bool(exchange and ctrl != mirrorCtrl)

    oldSel = []
    tmp = None
    if swap:
        # Remember the selection before the temporary node is created so it
        # can be restored afterwards.
        oldSel = core.ls(sl=True)
        tmp = core.createNode("transform")

    try:
        if swap:
            # Snap the temporary transform onto mirrorCtrl so its original pose
            # survives the overwrite below and can be mirrored back onto ctrl.
            core.delete(core.parentConstraint(mirrorCtrl, tmp))

        namespace = mirrorCtrl.namespace()
        # Scale part of main_control's local matrix.
        # NOTE(review): presumably compensates for the rig's global scale - confirm.
        sm = core.dt.TransformationMatrix(core.PyNode(namespace + "main_control").matrix.get()).asScaleMatrix()

        # ctrl relative to the center object, combined with the scale matrix.
        m = ctrl.worldMatrix.get() * centerObject.worldInverseMatrix.get() * sm

        # Center object's world matrix with its scale replaced by mirrorAxis.
        cmFn = core.dt.TransformationMatrix(centerObject.worldMatrix.get())
        cmFn.setScale(mirrorAxis, "object")

        m *= cmFn.asMatrix()
        # Convert into the parent space of the target control.
        m *= mirrorCtrl.parentInverseMatrix.get()

        mFn = core.dt.TransformationMatrix(m)
        mFn.setScale(mirrorScale, "object")

        if hasattr(mirrorCtrl, "getOrientation"):  # for bones
            # Take the joint orientation out of the rotation before assigning.
            jo = mirrorCtrl.getOrientation().asMatrix()
            rm = mFn.asRotateMatrix() * jo.inverse()

            tmFn = core.dt.TransformationMatrix(rm)
            tmFn.setTranslation(mFn.getTranslation("object"), "object")
            mirrorCtrl.setTransformation(tmFn.asMatrix())
        else:
            mirrorCtrl.setTransformation(mFn.asMatrix())

        if swap:
            # Mirror the saved pose of mirrorCtrl back onto ctrl (one way only).
            mirrorControl(tmp, ctrl, centerObject, mirrorAxis, mirrorScale, False)
    finally:
        # Always remove the temporary node and restore the selection, even
        # when one of the steps above raised.
        if tmp is not None:
            core.delete(tmp)
            if oldSel:
                core.select(oldSel)

def mirrorBend(namespace, fromSide, toSide, typ, centerObject=None):
    """
    Mirror the stretch and bend controls of a limb from one side to the other.

    Source controls are found by name pattern; the matching target name is
    built by replacing the first "_"-separated token of the source name with
    ``toSide``. Pairs whose target does not exist are skipped.

    Args:
        namespace (str): Character namespace (controls).
        fromSide (str): From where ("L" or "R").
        toSide (str): To where ("L" or "R").
        typ (str): "arm" or "leg".
        centerObject (str, optional): Name of the center object for mirroring,
            without the namespace (``namespace`` is prepended here). When not
            given, tx/ty/tz are copied with their sign flipped instead of
            calling mirrorControl. Defaults to None.
    """
    prefix = namespace + fromSide + "_" + typ
    stretchControls = core.ls(prefix + "_*_stretchRig_*_control")
    bendControls = core.ls(prefix + "_bend_*_control") + core.ls(prefix + "_bend_control")

    for fromEach in stretchControls + bendControls:
        names = fromEach.stripNamespace().split("_")
        names[0] = toSide
        toEach = namespace + "_".join(names)  # already fully namespaced

        if not (core.objExists(fromEach) and core.objExists(toEach)):
            continue

        if centerObject:
            # toEach already carries the namespace; prepending it a second time
            # produced a name that does not exist for namespaced characters.
            mirrorControl(fromEach, toEach, namespace + centerObject, [-1, 1, 1], [-1, -1, -1])
        else:
            fromNode = core.PyNode(fromEach)
            toNode = core.PyNode(toEach)

            for attrName in ["tx", "ty", "tz"]:
                toNode.attr(attrName).set(fromNode.attr(attrName).get() * (-1))

def mirrorArm(network, mirrorNetwork, namespace, side, mirrorSide, typ, exchange):
    """
    Mirror the controls of one limb network onto its counterpart.

    The ik, polevector and fk1 controls are mirrored around
    ``M_spine_fk_1_control``; fk2, fk3, the shoulder, footroll and kinematic
    controls have their attributes copied; finally the bend controls are
    mirrored with mirrorBend.

    Args:
        network (Network): Network object.
        mirrorNetwork (Network): Mirrored network object.
        namespace (str): Character namespace.
        side (str): Side ("L" or "R").
        mirrorSide (str): Mirrored side ("L" or "R").
        typ (str): "arm", "frontLeg", "leg", etc.
        exchange (bool): Whether to exchange attributes bidirectionally.
    """
    center = namespace + "M_spine_fk_1_control"

    def controlName(whichSide, suffix):
        # Full name of a side-specific control of this limb type.
        return namespace + whichSide + "_" + typ + suffix

    # Only arm-type limbs flip the Y scale of the ik control.
    ikScaleY = -1 if typ in ["arm", "arm_hier", "biped_rig_arm"] else 1

    # Controls mirrored by matrix, with the scale applied to each result.
    matrixMirrored = (
        ("ik", [-1, ikScaleY, 1]),
        ("polevector", [1, 1, 1]),
        ("fk1", [-1, -1, -1]),
    )
    for key, scale in matrixMirrored:
        mirrorControl(network.getAttr(key).node(), mirrorNetwork.getAttr(key).node(), center,
                      [-1, 1, 1], scale, exchange)

    # The remaining fk controls only need their attribute values copied.
    for key in ("fk2", "fk3"):
        copyAttrs(network.getAttr(key).node(), mirrorNetwork.getAttr(key).node(), exchange)

    # Shoulder controls are optional; translations are negated when copied.
    fromShoulder = controlName(side, "_shoulder_control")
    toShoulder = controlName(mirrorSide, "_shoulder_control")
    if core.objExists(fromShoulder) and core.objExists(toShoulder):
        copyAttrs(fromShoulder, toShoulder, exchange, reverseDict={"tx": -1, "ty": -1, "tz": -1})

    # Footroll control: taken from the network when it lists one, otherwise
    # looked up by its conventional name.
    if network.getAttr("footroll"):
        copyAttrs(network.getAttr("footroll").node(), mirrorNetwork.getAttr("footroll").node(), exchange,
                  reverseDict={"rz": -1})
    else:
        copyAttrs(controlName(side, "_footroll_control"), controlName(mirrorSide, "_footroll_control"),
                  exchange, reverseDict={"rz": -1})

    copyAttrs(network.getAttr("kinematic").node(), mirrorNetwork.getAttr("kinematic").node(), exchange)
    mirrorBend(namespace, side, mirrorSide, typ)


def mirrorDragonWing(network, mirrorNetwork, namespace, side, mirrorSide, typ, exchange):
    """
    Mirror a dragon wing network onto its counterpart.

    The wing is first mirrored as an "arm" through mirrorArm; afterwards the
    attributes of every wing finger, elbow and membrane control that both
    networks declare are copied.

    Args:
        network (Network): Network object.
        mirrorNetwork (Network): Mirrored network object.
        namespace (str): Character namespace.
        side (str): Side ("L" or "R").
        mirrorSide (str): Mirrored side ("L" or "R").
        typ (str): "dragonWing". Not used here; mirrorArm is called with "arm".
        exchange (bool): Whether to exchange attributes bidirectionally.
    """
    mirrorArm(network, mirrorNetwork, namespace, side, mirrorSide, "arm", exchange)

    wingControlAttrs = (
        "wing_index", "wing_index_1", "wing_index_2", "wing_index_3",
        "wing_middle", "wing_middle_1", "wing_middle_2", "wing_middle_3",
        "wing_ring", "wing_ring_1", "wing_ring_2", "wing_ring_3",
        "wing_pinky", "wing_pinky_1", "wing_pinky_2", "wing_pinky_3",
        "wing_elbow", "wing_elbow_1", "wing_elbow_2", "wing_elbow_3",
        "membrane_1", "membrane_2", "membrane_3",
        "membrane_4", "membrane_5", "membrane_6",
    )

    def declares(net, attrName):
        # True-ish when the network node exists and carries the attribute.
        return net.get() and net.get().hasAttr(attrName)

    for attrName in wingControlAttrs:
        if not (declares(network, attrName) and declares(mirrorNetwork, attrName)):
            continue
        copyAttrs(network.getAttr(attrName).node(), mirrorNetwork.getAttr(attrName).node(), exchange)

def mirrorByNetwork(network, exchange=False):
    """
    Mirror the controls described by a network, dispatching on its type.

    The side is the first "_"-separated token of the network node name; the
    mirrored network is the node with the same name on the opposite side
    ("L" <-> "R", anything else maps to "M"). Network types that are not
    handled below are ignored.

    Args:
        network (str): Network name.
        exchange (bool, optional): Whether to exchange attributes bidirectionally. Defaults to False.
    """
    network = Network(network)

    typ = network.getAttr("type")
    namespace = network.get().namespace()
    side = network.get().stripNamespace().split("_")[0]

    mirrorSide = {"L": "R", "R": "L"}.get(side, "M")
    nameTail = network.get().stripNamespace().split("_")[1:]
    mirrorNetwork = Network(namespace + mirrorSide + "_" + "_".join(nameTail))

    # --- arms ---------------------------------------------------------------
    if typ in ["arm", "frontLeg", "arm_hier", "biped_rig_arm"]:
        mirrorArm(network, mirrorNetwork, namespace, side, mirrorSide, typ, exchange)
        return

    # --- legs ---------------------------------------------------------------
    if typ in ["leg", "backLeg", "midLeg", "leg_hier", "biped_rig_leg"]:
        center = namespace + "M_spine_fk_1_control"

        # Controls mirrored by matrix, with the scale applied to each result.
        matrixMirrored = (
            ("ik", [-1, 1, 1]),
            ("polevector", [1, 1, 1]),
            ("fk1", [-1, -1, -1]),
        )
        for key, scale in matrixMirrored:
            mirrorControl(network.getAttr(key).node(), mirrorNetwork.getAttr(key).node(), center,
                          [-1, 1, 1], scale, exchange)

        # Plain attribute copies.
        for key in ("fk2", "fk3", "toe_ik", "toe_fk"):
            copyAttrs(network.getAttr(key).node(), mirrorNetwork.getAttr(key).node(), exchange)

        # Back shoulder control (optional); translations are negated.
        fromShoulder = namespace + side + "_" + typ + "_shoulder_control"
        toShoulder = namespace + mirrorSide + "_" + typ + "_shoulder_control"
        if core.objExists(fromShoulder) and core.objExists(toShoulder):
            copyAttrs(fromShoulder, toShoulder, exchange, reverseDict={"tx": -1, "ty": -1, "tz": -1})

        # Footroll control: from the network when listed, otherwise by name.
        if network.getAttr("footroll"):
            copyAttrs(network.getAttr("footroll").node(), mirrorNetwork.getAttr("footroll").node(), exchange,
                      reverseDict={"rz": -1})
        else:
            copyAttrs(namespace + side + "_" + typ + "_footroll_control",
                      namespace + mirrorSide + "_" + typ + "_footroll_control", exchange,
                      reverseDict={"rz": -1})

        copyAttrs(network.getAttr("kinematic").node(), mirrorNetwork.getAttr("kinematic").node(), exchange)
        mirrorBend(namespace, side, mirrorSide, typ)
        return

    # --- spine --------------------------------------------------------------
    if typ == "spine":
        # Each control is mirrored onto itself around its own "_null" node,
        # walking every listed group in reverse order.
        for pattern in ("M_spine_fix_*_control", "M_spine_fk_*_control", "M_spine_ik_*_control"):
            for spineCtrl in core.ls(namespace + pattern, type="transform")[::-1]:
                mirrorControl(spineCtrl, spineCtrl, spineCtrl + "_null", (-1, 1, 1), (-1, 1, 1), False)
        return

    # --- head and neck ------------------------------------------------------
    if typ in ["head", "head_hier", "biped_rig_head"]:
        middleCtrl = namespace + "M_head_ik_middle_control"
        if core.objExists(middleCtrl):
            mirrorControl(middleCtrl, namespace + "M_head_ik_middle_control",
                          namespace + "M_head_ik_middle_control_null", (-1, 1, 1), (-1, 1, 1), False)

        center = namespace + "M_spine_fk_1_control"
        for headName in ("M_head_fk_control", "M_head_ik_control"):
            mirrorControl(namespace + headName, namespace + headName, center, (-1, 1, 1), (-1, 1, 1), False)

        for neckCtrl in core.ls(namespace + "M_neck_fk_*_control", type="transform")[::-1]:
            mirrorControl(neckCtrl, neckCtrl, neckCtrl + "_null", (-1, 1, 1), (-1, 1, 1), False)
        return

    # --- fingers ------------------------------------------------------------
    if typ == "fingers":
        fingersCtrl = network.getAttr("fingers").node()
        mirroredFingersCtrl = mirrorNetwork.getAttr("fingers").node()
        copyAttrs(fingersCtrl, mirroredFingersCtrl, exchange)
        return

    # --- dragon wing --------------------------------------------------------
    if typ == "dragonWing":
        mirrorDragonWing(network, mirrorNetwork, namespace, side, mirrorSide, typ, exchange)
