from pxr.UsdQtEditors._Qt import QtCore, QtWidgets, QtGui
from pxr import Sdf, Tf, Usd, UsdGeom, UsdUtils
from pxr.UsdQt.hierarchyModel import HierarchyStandardModel as UsdQtModel, \
    HierarchyBaseModel, HierarchyStandardFilterModel

from .i18n import i18n
from .modelColumn import NameColumn, VisibilityColumn, MetaDataColumn, VariantColumn, MaterialColumn, \
    ColumnController, IsOnEditTargetRole

import six
from six.moves import range

class HierarchyModel(UsdQtModel):
    """Configurable model for displaying USD prim hierarchies.

    The tree structure comes from the UsdQt hierarchy model this class
    extends; per-column display data, editing and item flags are delegated to
    the column objects held by a ``ColumnController`` (see ``default_columns``).
    """
    ArcsIconPath = None
    NoarcsIconPath = None

    # Highlight state replaced via updateHighlighted(); the element type is
    # defined by callers. Kept as a class attribute for backward compatibility,
    # but __init__ gives every instance its own list so in-place mutation is
    # never shared between models.
    highlighted = []

    # After changing visibility, the Maya stage needs to be updated immediately.
    needUpdateStage = QtCore.Signal()

    # When a node is selected in Maya, its parents need to be repainted and
    # highlighted in the outliner.
    repaintView = QtCore.Signal()

    needSelectIndexes = QtCore.Signal(list)

    # Roles answered by the column objects rather than by the base model.
    _COLUMN_ROLES = (
        QtCore.Qt.DisplayRole,
        QtCore.Qt.ForegroundRole,
        QtCore.Qt.DecorationRole,
        QtCore.Qt.FontRole,
        QtCore.Qt.EditRole,
        QtCore.Qt.ToolTipRole,
        QtCore.Qt.BackgroundRole,
        IsOnEditTargetRole,
    )

    def __init__(self, stage=None, parent=None, show_all_flag=True):
        """
        :param stage: Usd.Stage to display (may be None).
        :param parent: QtWidget parent
        :param show_all_flag: if False, children of prims whose ``kind``
            metadata is 'component' are not shown (see rowCount).
        """
        # PrimIsDefined | ~PrimIsDefined matches every prim, defined or not;
        # TraverseInstanceProxies additionally walks into instance proxies.
        UsdQtModel.__init__(
            self,
            stage,
            Usd.TraverseInstanceProxies(Usd.PrimIsDefined | ~Usd.PrimIsDefined),
            parent)

        self.highlighted = []
        self._time_code = Usd.TimeCode.Default()
        self._show_pcp_errors = False
        self.show_all_flag = show_all_flag
        self.default_columns()

    def ResetStage(self, stage):
        """Reset the model to ``stage`` and (re)register the stage listener.

        The listener receives every Usd.Notice.StageNotice from the stage and
        is dropped when the new stage is not valid.
        """
        UsdQtModel.ResetStage(self, stage)
        if self._IsStageValid():
            self._listener_stage = Tf.Notice.Register(
                Usd.Notice.StageNotice, self._edit_target_changed, self._stage)
        else:
            # Never keep a listener bound to a previous stage.
            self._listener_stage = None

    def _Invalidate(self):
        """Drop the stage, the prim index and both notice listeners."""
        self._stage = None
        self._index = None
        self._listener = None
        self._listener_stage = None

    def _edit_target_changed(self, notice, sender):
        """Request a repaint when the stage's edit target changes."""
        if isinstance(notice, Usd.Notice.StageEditTargetChanged):
            self.repaintView.emit()

    def change_show_pcp_errors_flag(self, state):
        """Store the 'show PCP errors' flag and request a repaint."""
        self._show_pcp_errors = state
        self.repaintView.emit()

    def default_columns(self):
        """Build the default column layout.

        Column order is Vis, Name, Type, Kind, Full material, Preview material,
        then one column per registered variant set. Every column after Name is
        hidden by default. The tree column is set from ``self._columns.name``
        (presumably the Name column).
        """
        self._columns = ColumnController()
        self._columns.add_column(VisibilityColumn, "Vis", i18n("outliner.hierarchy_model", "Vis"))
        self._columns.add_column(NameColumn, "Name", i18n("outliner.hierarchy_model", "Name"))

        type_column = self._columns.add_column(MetaDataColumn, "Type", i18n("outliner.hierarchy_model", "Type"))
        type_column.set_visibility(False)
        kind_column = self._columns.add_column(MetaDataColumn, "Kind", i18n("outliner.hierarchy_model", "Kind"))
        kind_column.set_visibility(False)

        from pxr import UsdShade
        full_material_column = self._columns.add_column(
            MaterialColumn, "Full material",
            i18n("outliner.hierarchy_model", "Full material"), UsdShade.Tokens.full)
        full_material_column.set_visibility(False)
        preview_material_column = self._columns.add_column(
            MaterialColumn, "Preview material",
            i18n("outliner.hierarchy_model", "Preview material"), UsdShade.Tokens.preview)
        preview_material_column.set_visibility(False)

        for var_set in UsdUtils.GetRegisteredVariantSets():
            variant_column = self._columns.add_column(VariantColumn, var_set.name, var_set.name)
            variant_column.set_visibility(False)

        self._columns.set_tree_index(self._columns.index(self._columns.name.header_name))

    def columnCount(self, parent=QtCore.QModelIndex()):
        """Return the number of configured columns (independent of ``parent``)."""
        return self._columns.count()

    def _OnObjectsChanged(self, notice, sender):
        """Resync the prim index for the notice's resynced paths.

        Persistent indexes at or below a resynced path, or at or below one of
        its siblings, are remapped to their new rows after the resync. They
        are invalidated when their prim is gone, unless the single-rename
        heuristic below finds a replacement. ``needUpdateStage`` is always
        emitted at the end.
        """
        resynced_paths = notice.GetResyncedPaths()
        update_paths = set(resynced_paths)

        # If the index root itself was resynced, treat each of its direct
        # children as resynced too.
        root_proxy = self._index.GetRoot()
        if root_proxy.GetPrim().GetPath() in resynced_paths:
            for i in range(self._index.GetChildCount(root_proxy)):
                proxy = self._index.GetChild(root_proxy, i)
                update_paths.add(proxy.GetPrim().GetPath())

        if update_paths:
            with self.LayoutChangedContext(self):
                # Remember the prim path of every persistent index the resync
                # may move: those under a resynced path (indexIsChild) or
                # under one of its siblings (areSiblings).
                indexToPath = {}
                for index in self.persistentIndexList():
                    indexPath = index.internalPointer().GetPrim().GetPath()
                    for resyncedPath in update_paths:
                        commonPath = resyncedPath.GetCommonPrefix(indexPath)
                        areSiblings = (commonPath == resyncedPath.GetParentPath()
                                       and commonPath != indexPath)
                        indexIsChild = (commonPath == resyncedPath)
                        if areSiblings or indexIsChild:
                            indexToPath[index] = indexPath
                            break

                self._index.ResyncSubtrees(update_paths)

                old_paths = update_paths.intersection(indexToPath.values())
                new_paths = update_paths.difference(old_paths)

                fromIndices = []
                toIndices = []
                for index, path in indexToPath.items():
                    if self._index.ContainsPath(path):
                        # The prim survived the resync: just move the index.
                        target_path = path
                    elif len(new_paths) == 1 and len(old_paths) == 1:
                        # HACK for renaming a prim: with exactly one vanished
                        # path and one new path, assume it is a rename. This is
                        # error-prone: e.g. deleting one prim and adding an
                        # unrelated one in the same SdfChangeBlock (or several
                        # renames at once) will be mis-mapped, but there is no
                        # better signal available for now.
                        target_path = next(iter(new_paths))
                        self.force_load_children_to_path(target_path)
                        if not self._index.ContainsPath(target_path):
                            # Leave the index untouched, as before.
                            continue
                    else:
                        fromIndices.append(index)
                        toIndices.append(QtCore.QModelIndex())
                        continue

                    newProxy = self._index.GetProxy(target_path)
                    newRow = self._index.GetRow(newProxy)
                    # Each persistent index already carries its own column, so
                    # it is mapped exactly once (no per-column duplication).
                    fromIndices.append(index)
                    toIndices.append(self.createIndex(newRow, index.column(), newProxy))

                self.changePersistentIndexList(fromIndices, toIndices)

        self.needUpdateStage.emit()

    def set_time_code(self, time_code):
        """Store the time code used by the columns.

        The layout is deliberately not refreshed here so animation playback
        is not slowed down.
        """
        self._time_code = time_code

    def data(self, modelIndex, role=QtCore.Qt.DisplayRole):
        """Return column-provided data for the roles in ``_COLUMN_ROLES``.

        Other roles fall back to HierarchyBaseModel.data (UsdQtModel's own
        implementation is deliberately bypassed). Invalid indexes yield None.
        """
        if not modelIndex.isValid():
            return None

        if role in self._COLUMN_ROLES:
            return self._columns[modelIndex.column()].data(modelIndex, role)

        return HierarchyBaseModel.data(self, modelIndex, role)

    def setData(self, modelIndex, value, role=QtCore.Qt.EditRole):
        """Delegate editing to the column object owning ``modelIndex``."""
        return self._columns[modelIndex.column()].setData(modelIndex, value, role)

    def flags(self, modelIndex):
        """Delegate item flags to the column object owning ``modelIndex``."""
        return self._columns[modelIndex.column()].flags(modelIndex)

    def rowCount(self, parent=QtCore.QModelIndex()):
        """Return the number of child rows under ``parent``.

        An invalid parent has exactly one row (the index root). When
        ``show_all_flag`` is False, prims whose ``kind`` is 'component'
        report no children. Returns 0 when no valid stage is loaded.
        """
        if not self._IsStageValid():
            return 0

        if not parent.isValid():
            return 1

        parentProxy = parent.internalPointer()

        if not self.show_all_flag:
            kind = parentProxy.GetPrim().GetMetadata('kind')
            if kind == 'component':
                return 0

        return self._index.GetChildCount(parentProxy)

    def getColumnIndex(self, header_name):
        """Return the position of the column named ``header_name``, or None."""
        for i in range(self._columns.count()):
            if header_name == self._columns[i].header_name:
                return i
        return None

    def updateHighlighted(self, highlighted):
        """Replace the highlight state and request a repaint."""
        self.highlighted = highlighted
        self.repaintView.emit()

    def headerData(self, section, orientation, role):
        """Return column header data, falling back to the base implementation.

        ``super(UsdQtModel, self)`` starts the lookup after UsdQtModel in the
        MRO, so UsdQtModel's own headerData is skipped.
        """
        data = self._columns[section].headerData(role)
        if data is not None:
            return data
        return super(UsdQtModel, self).headerData(section, orientation, role)

    def set_columns(self, columns):
        """Replace the whole column controller (resets the model)."""
        self.beginResetModel()
        self._columns = columns
        self.endResetModel()

    def add_column(self, column_type, header_name, display_name, *args, **kwargs):
        """Add a column (resets the model) and return the new column object."""
        self.beginResetModel()
        column = self._columns.add_column(column_type, header_name, display_name, *args, **kwargs)
        self.endResetModel()
        return column

    def remove_column(self, index):
        """Remove the column at ``index`` (resets the model)."""
        self.beginResetModel()
        self._columns.remove_column(index)
        self.endResetModel()

    def remove_column_by_name(self, header_name):
        """Remove the column named ``header_name`` (resets the model)."""
        self.beginResetModel()
        self._columns.remove_column_by_name(header_name)
        self.endResetModel()

    def GetIndexForPath(self, path):
        """Return a model index for the prim at ``path``, or None if not indexed.

        The index is created in column 1, which is the Name column in the
        layout built by ``default_columns``.
        """
        if self._index and self._index.ContainsPath(path):
            proxy = self._index.GetProxy(path)
            row = self._index.GetRow(proxy)
            return self.createIndex(row, 1, proxy)
        return None

    def force_load_children_to_path(self, path):
        """Walk the index from the deepest already-indexed ancestor of ``path``
        down to ``path``, requesting every child of each prefix.

        Does nothing if there is no prim index or no prefix of ``path``
        (including the absolute root) is indexed yet.
        """
        if not self._index:
            return

        prefixes = Sdf.Path(path).GetPrefixes()
        # Include the absolute root so root prim children are (re)loaded too.
        prefixes.insert(0, Sdf.Path('/'))

        # Position of the deepest prefix already present in the index.
        start = next(
            (i for i in reversed(range(len(prefixes)))
             if self._index.ContainsPath(prefixes[i])),
            None)
        if start is None:
            return

        for prefix in prefixes[start:]:
            prim_index = self.GetIndexForPath(prefix)
            if prim_index is None or not prim_index.isValid():
                continue
            proxy = prim_index.internalPointer()
            for row in range(self.rowCount(prim_index)):
                # The returned child is unused; the call is made only for its
                # side effect (presumably populating the index lazily so that
                # the next, deeper prefix becomes reachable).
                self._index.GetChild(proxy, row)

    def supportedDropActions(self):
        """Only copy drops are supported."""
        return QtCore.Qt.CopyAction

    def mimeTypes(self):
        """MIME types produced by ``mimeData``."""
        return ["application/x-sdfpaths"]

    def mimeData(self, indexes):
        """Encode the prim paths of ``indexes`` as newline-separated text.

        Only column-0 indexes are used, so each selected row contributes one
        path. The Sdf.Path lists are also attached to the returned object as
        ``paths_list`` and ``all_paths_list``.
        """
        paths = [index.internalPointer().GetPrim().GetPath()
                 for index in indexes if index.column() == 0]

        mimedata = QtCore.QMimeData()
        # NOTE(review): six.b encodes with latin-1 on Python 3; confirm no prim
        # path can contain characters outside that range.
        mimedata.setData("application/x-sdfpaths",
                         six.b('\n'.join(path.pathString for path in paths)))
        # Every dragged path is currently treated as movable (a former
        # IsOnEditTargetRole check is disabled); kept as a separate list.
        mimedata.paths_list = list(paths)
        mimedata.all_paths_list = paths
        return mimedata

    def columns_controller(self):
        """Return the ColumnController in use."""
        return self._columns

    def get_stage(self):
        """Return the displayed stage (None once invalidated)."""
        return self._stage


class FilterModel(HierarchyStandardFilterModel):
    """Filter proxy over a HierarchyModel that forwards the model's custom API.

    Most methods delegate to the source model, mapping proxy indexes to
    source indexes (and back) where needed.
    """
    needSelectIndexes = QtCore.Signal(list)

    def __init__(self, parent=None):
        HierarchyStandardFilterModel.__init__(self, parent)
        self._show_inactive = True
        self._show_undefined = True
        self._show_abstract = False

    def set_show_flags(self, show_inactive, show_undefined, show_abstract):
        """Store the inactive/undefined/abstract display flags.

        NOTE(review): the flags are only stored; nothing here re-runs the
        filter — confirm whether callers invalidate it afterwards.
        """
        self._show_inactive = show_inactive
        self._show_undefined = show_undefined
        self._show_abstract = show_abstract

    def setSourceModel(self, source_model):
        """Set the source model and forward ``needSelectIndexes`` to it.

        ``source_model`` may be None (Qt allows detaching a proxy); in that
        case no signal forwarding is connected.
        """
        QtCore.QSortFilterProxyModel.setSourceModel(self, source_model)
        if source_model is not None:
            # Signal-to-signal connection: emitting ours re-emits the source's.
            self.needSelectIndexes.connect(source_model.needSelectIndexes)

    def GetIndexForPath(self, path):
        """Return the proxy index for the prim at ``path``, or None."""
        source_index = self.sourceModel().GetIndexForPath(path)
        if source_index:
            return self.mapFromSource(source_index)
        return None

    def _GetPrimForIndex(self, index):
        """Return the prim behind proxy ``index`` as reported by the source model."""
        source_index = self.mapToSource(index)
        if source_index:
            return self.sourceModel()._GetPrimForIndex(source_index)
        return None

    def _IsStageValid(self):
        return self.sourceModel()._IsStageValid()

    def getColumnIndex(self, title):
        return self.sourceModel().getColumnIndex(title)

    def columns_controller(self):
        return self.sourceModel().columns_controller()

    def supportedDropActions(self):
        return self.sourceModel().supportedDropActions()

    def mimeTypes(self):
        return self.sourceModel().mimeTypes()

    def mimeData(self, indexes):
        """Map ``indexes`` to the source model and build its MIME data.

        Returns None when ``indexes`` is empty.
        """
        source_indexes = [self.mapToSource(index) for index in indexes]
        if source_indexes:
            return self.sourceModel().mimeData(source_indexes)
        return None

    def ResetStage(self, stage):
        self.sourceModel().ResetStage(stage)

    def get_stage(self):
        return self.sourceModel().get_stage()

    def SetPathStartsWithFilter(self, substring):
        """Filter rows to paths starting with ``substring``.

        An empty string (or None) disables the cached filter.
        """
        if not substring:
            self._filterCacheActive = False
        else:
            self._filterCache.ApplyPathStartsWithFilter(self.sourceModel().GetRoot(),
                                                        substring,
                                                        self._filterCachePredicate)
            self._filterCacheActive = True
        self.invalidateFilter()

if __name__ == '__main__':
    # Manual smoke test: show a test stage in a bare QTreeView.
    import sys
    import os

    # Reuse the module-level QtWidgets (from pxr.UsdQtEditors._Qt) instead of a
    # package-relative ``._Qt`` import, which is not guaranteed to exist here
    # and could otherwise mix Qt shims with the ones the model uses.
    app = QtWidgets.QApplication(sys.argv)
    test_dir = os.path.dirname(__file__)
    path = os.path.join(
        test_dir, 'testenv', 'testUsdQtHierarchyModel', 'simpleHierarchy.usda')
    stage = Usd.Stage.Open(path)
    model = HierarchyModel(stage)

    tv = QtWidgets.QTreeView()
    tv.setModel(model)
    tv.show()

    sys.exit(app.exec_())
