from pxr.UsdQt import _usdQt as UsdQt
import pymel.internal.factories
from pxr import Tf

def _edit_target_changed_callback(notice, stage):
    if stage:
        UsdQt.UndoRouter.TrackLayer(stage.GetEditTarget().GetLayer())

def _undo_stack_notice_callback(notice, sender):
    """Move pending UsdQt edits onto pymel's API undo stack.

    The edits collected by ``UsdQt.UndoRouter`` are transferred into a new
    ``UsdQt.UndoInverse``. That inverse is wrapped in a ``UsdQtUndoCommand``
    and appended to ``pymel.internal.factories.apiUndo``.

    Args:
        notice: The ``UsdQt.UndoStackNotice`` being delivered (unused).
        sender: The notice sender (unused).
    """
    import logging

    inverse = UsdQt.UndoInverse()
    UsdQt.UndoRouter.TransferEdits(inverse)
    cmd = UsdQtUndoCommand(inverse)
    undo_stack = pymel.internal.factories.apiUndo
    # pymel's append can sometimes fail while it checks the cmds.undoInfo
    # state. Registering the command is best-effort, so the failure is
    # logged instead of propagated into the notice dispatch.
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        undo_stack.append(cmd)
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not add USD edit to the pymel undo stack; "
            "it may not be undoable from Maya.",
            exc_info=True,
        )


from pxr.UsdQt._usdQt import UndoBlock

class UsdQtUndoCommand(object):
    """pymel API-undo command wrapping an invertible UsdQt edit block.

    ``redoIt`` does nothing until ``undoIt`` has run once. The wrapped edits
    are presumably already applied when the command is built. From then on,
    every ``undoIt`` and every ``redoIt`` call ``Invert()`` on the block.
    """

    def __init__(self, usdqt_block):
        super(UsdQtUndoCommand, self).__init__()
        # Object exposing Invert(); this module passes a UsdQt.UndoInverse.
        self.usdqt_block = usdqt_block
        # Stays True until the first undoIt(); while True, redoIt() is a no-op.
        self.first_time = True

    def redoIt(self):
        """Re-apply the edits, unless no undo has happened yet."""
        if not self.first_time:
            self.usdqt_block.Invert()

    def undoIt(self):
        """Revert the edits and enable inversion on later redoIt calls."""
        self.usdqt_block.Invert()
        self.first_time = False

# Listener returned by Tf.Notice.RegisterGlobally in init_usd_qt_undo; stays
# None until registration happens. Held at module level, presumably so the
# listener object stays alive (and registered) for the session — TODO confirm
# against Tf.Notice docs.
UndoListener = None

def init_undo_for_stage(stage):
    """Begin UsdQt undo tracking for the layer *stage* currently edits.

    Args:
        stage: A USD stage. The layer of its current edit target is passed to
            ``UsdQt.UndoRouter.TrackLayer``.
    """
    target_layer = stage.GetEditTarget().GetLayer()
    UsdQt.UndoRouter.TrackLayer(target_layer)

def init_usd_qt_undo():
    """Register the global UsdQt undo-stack listener, at most once.

    On the first call, ``_undo_stack_notice_callback`` is registered globally
    for ``UsdQt.UndoStackNotice``. The returned listener is stored in the
    module-level ``UndoListener``. Later calls do nothing.
    """
    global UndoListener
    if UndoListener is not None:
        # Already registered; keep the existing listener.
        return
    UndoListener = Tf.Notice.RegisterGlobally(
        UsdQt.UndoStackNotice, _undo_stack_notice_callback)