
import json
import os
import re

import maya.cmds as cmds
import pymel.core as pm
import wizart.maya_qtutils
import wref_resolver
from Qt import QtCore, QtGui, QtWidgets
from .wref_context_manager import utils as wref_utils

from .utils import check_alusdmaya_plugin
# Project root, stored unexpanded; callers run os.path.expandvars() on it
# (and on the paths derived from it) at use time.
ROOT_PATH = "$PROJECT_ROOT_PATH"

# Pipeline configuration directory and the YAML files read from it.
CONFIG_PATH = os.path.join(ROOT_PATH, ".config")

RESOLUTION_CONFIG_PATH = os.path.join(CONFIG_PATH, "resolution.yml")
PROJECT_CONFIG_PATH = os.path.join(CONFIG_PATH, "project.yml")
ANIM_PREP_CONFIG = os.path.join(CONFIG_PATH, "anim_prep.yml")
# Variable whose expanded value is stripped from ROOT_PATH to build wref URIs.
W_VAR = "$W"
# Index of the last item of a sequence (e.g. the extension from splitext()).
LAST_ELEMENT = -1

# Display names for each hierarchy level (keys are the internal level names),
# used when the project config enables "use_shots_name_convention".
SHOTS_NAME_CONVENTION = {
    "shots": "shots",
    "sequences": "sequences",
    "episodes": "episodes",
}

# Display names for each hierarchy level otherwise.
SCENES_NAME_CONVENTION = {
    "shots": "scenes",
    "sequences": "episodes",
    "episodes": "series",
}


def get_wref_root():
    """Return the project root expressed as a ``wref:/`` URI.

    Both ``ROOT_PATH`` and ``W_VAR`` are environment-expanded; the expanded
    ``$W`` value is removed from the front of the expanded project root and
    the remainder is appended to ``"wref:/"``.  For example, with
    ``W=/mnt/w`` and ``PROJECT_ROOT_PATH=/mnt/w/projects/foo`` the result is
    ``"wref://projects/foo"``.
    """
    expanded_root = os.path.expandvars(ROOT_PATH)
    expanded_w = os.path.expandvars(W_VAR)
    # Strip $W only as a leading prefix: str.replace() would also delete any
    # later occurrence of the same text inside the project path.
    if expanded_w and expanded_root.startswith(expanded_w):
        local_path = expanded_root[len(expanded_w):]
    else:
        local_path = expanded_root
    return "wref:/{}".format(local_path)


def correct_scene_resolution():
    """Set Maya's ``defaultResolution`` width/height from ``resolution.yml``.

    Missing ``width`` / ``height`` entries fall back to 2048 and 858.
    """
    # The YAML loader may yield None (e.g. for an empty file); treat that as
    # "no overrides" instead of failing on .get().
    resolution = resolution_from_config() or {}
    width = resolution.get("width", 2048)
    height = resolution.get("height", 858)
    pm.setAttr("defaultResolution.width", width)
    pm.setAttr("defaultResolution.height", height)


def pin_rigs_for_scene(path, stage):
    """Pin the resolved version of every ``.mb`` reference used by *path*.

    Each reference found by ``wref_utils.get_from_usd_paths`` on the resolved
    *path* whose base path ends in ``.mb`` is resolved again to read its
    version.  If any were found, ``{"pin": {base_path: version}}`` is
    serialized to JSON, assigned to the stage's path-resolver context
    ``config`` and stored in the scene's ``fileInfo`` under
    ``"wref_resolver_config"``.

    :param path: wref path of the scene USD.
    :param stage: USD stage whose resolver context receives the pin config.
    """
    references = wref_utils.get_from_usd_paths(wref_resolver.resolve(path))
    pinned_versions = {}
    for reference in references:
        base_path = wref_utils.get_basepath(reference)
        extension = os.path.splitext(base_path)[LAST_ELEMENT]
        if extension != '.mb':
            continue
        info = wref_resolver.AssetInfo()
        wref_resolver.resolve(reference, info)
        pinned_versions[base_path] = info.version

    if not pinned_versions:
        return

    serialized_config = json.dumps(dict(pin=pinned_versions))
    # The resolver's python binding has to be imported before the resolver
    # context is queried from the stage.
    import usd_wref_resolver
    resolver_context = stage.GetPathResolverContext()
    resolver_context.config = serialized_config

    pm.system.fileInfo["wref_resolver_config"] = serialized_config

# Matches names that already start with an ASCII letter (e.g. "ep01", "sc010").
_LEADING_LETTER_RE = re.compile(r"^[a-zA-Z]")


def _prefix_unless_alpha(prefix, value):
    """Return *value*, prefixed with *prefix* unless it already starts with a letter."""
    if _LEADING_LETTER_RE.search(value) is None:
        return "{}{}".format(prefix, value)
    return value


@check_alusdmaya_plugin
def start_from_anim_prep_scene(episode, scene):
    """Open a fresh scene and load the latest anim-prep USD of an episode scene.

    The wref path is built from the fixed ``episodes/<episode>/scenes/<scene>``
    layout under :func:`get_wref_root`, requesting the ``latest`` tag of
    ``<ep>_<sc>.usd``.  The imported proxy shape is named ``<ep>_<sc>`` and
    then configured by :func:`load_stage`.

    :param episode: episode folder name; gets an ``"ep"`` prefix in the file
        and node name unless it already starts with a letter.
    :param scene: scene folder name; gets an ``"sc"`` prefix in the file and
        node name unless it already starts with a letter.
    """
    formatted_episode = _prefix_unless_alpha("ep", episode)
    formatted_scene = _prefix_unless_alpha("sc", scene)
    # Folder names use the raw values; only the file name uses the prefixed ones.
    anim_prep_path_template = (
        "{wref_root}/episodes/{episode}/scenes/{scene}/def/"
        "?path={formatted_episode}_{formatted_scene}.usd&tag=latest"
    )
    anim_prep_path = anim_prep_path_template.format(
        wref_root=get_wref_root(), episode=episode, scene=scene,
        formatted_episode=formatted_episode, formatted_scene=formatted_scene)

    # Always start from a fresh scene; going through MEL's performNewScene lets
    # Maya ask the user to save the current scene first.
    pm.mel.performNewScene(0)
    node_name = cmds.AL_usdmaya_ProxyShapeImport(
        file=anim_prep_path,
        name="{episode}_{scene}".format(episode=formatted_episode, scene=formatted_scene))
    if node_name:
        load_stage(anim_prep_path, node_name)

@check_alusdmaya_plugin
def start_from_anim_prep_shot(anim_prep_path, name):
    """Open a fresh scene and import *anim_prep_path* as a proxy shape named *name*.

    When the import returns a node, the scene is configured from its stage
    via :func:`load_stage`.
    """
    # Run performNewScene through MEL so the user is asked to save the
    # current scene before it is replaced.
    pm.mel.performNewScene(0)
    imported_nodes = cmds.AL_usdmaya_ProxyShapeImport(file=anim_prep_path, name=name)
    if not imported_nodes:
        return
    load_stage(anim_prep_path, imported_nodes)

@check_alusdmaya_plugin
def load_stage(anim_prep_path, node_name):
    """Configure the current Maya scene from an imported AL_usdmaya proxy shape.

    Sets the time unit and playback range from the stage, pins ``.mb`` rig
    versions (:func:`pin_rigs_for_scene`), applies the configured render
    resolution and, if the stage's ``customLayerData`` has a ``sound`` entry,
    references that file and offsets every ``audio`` node in the scene to the
    stage start time.

    :param anim_prep_path: wref path the proxy shape was imported from.
    :param node_name: result of ``AL_usdmaya_ProxyShapeImport``; only its
        first entry is used.
    """
    import AL.usdmaya
    proxy_shape = AL.usdmaya.ProxyShape.getByName(node_name[0])
    stage = proxy_shape.getUsdStage()
    start_time = stage.GetStartTimeCode()
    end_time = stage.GetEndTimeCode()
    fps = stage.GetFramesPerSecond()
    # "{:g}" keeps fractional rates such as 24000/1001 as "23.976fps" instead
    # of rounding them to a wrong integer rate ("24fps"); integral rates give
    # the same string as before (e.g. 24.0 -> "24fps").
    # NOTE(review): fractional time-unit names require a Maya version that
    # supports them — confirm for the studio's Maya release.
    cmds.currentUnit(time="{:g}fps".format(fps))
    # Until the stage carries metadata for it, assume one extra frame on each
    # side of the stage range is needed for motion blur.
    cmds.playbackOptions(e=True, animationStartTime=start_time - 1.0)
    cmds.playbackOptions(e=True, animationEndTime=end_time + 1.0)
    cmds.playbackOptions(e=True, min=start_time)
    cmds.playbackOptions(e=True, max=end_time)
    pin_rigs_for_scene(anim_prep_path, stage)
    correct_scene_resolution()
    custom_layer_data = stage.GetPseudoRoot().GetMetadata("customLayerData")
    if custom_layer_data and "sound" in custom_layer_data:
        pm.createReference(custom_layer_data["sound"].path)
        for sound_node in pm.ls(type="audio"):
            # Offset the sound so it starts at the stage start time.
            sound_node.offset.set(start_time)


def get_config(path):
    """Load the YAML document at *path*.

    A ``!regexp`` tag is registered so scalar values tagged with it are
    returned as compiled regular expressions.

    :param path: filesystem path of the YAML file (already env-expanded).
    :return: the parsed document, or ``{}`` when the file does not exist,
        cannot be read, or contains no document / a null document.
    """
    if not os.path.exists(path):
        return {}

    # Imported lazily so a missing config never requires PyYAML.
    import yaml
    # NOTE(review): yaml.Loader uses PyYAML's unsafe constructor, which can
    # instantiate arbitrary Python objects; acceptable only while these
    # config files stay under pipeline control.
    yaml.add_constructor(
        '!regexp', lambda loader, node: re.compile(loader.construct_scalar(node)))
    try:
        with open(path, 'r') as stream:
            loader = yaml.Loader(stream)
            try:
                if not loader.check_data():
                    print("Cannot load yaml file %s" % path)
                    return {}
                data = loader.get_data()
            finally:
                loader.dispose()
    except IOError:
        # Best effort: an unreadable config behaves like a missing one.
        return {}
    # Callers immediately call .get() on the result, so never hand back None.
    return {} if data is None else data


def resolution_from_config():
    """Return the parsed ``resolution.yml`` from the project config directory."""
    config_file = os.path.expandvars(RESOLUTION_CONFIG_PATH)
    return get_config(config_file)


def get_project_config():
    """Return the parsed ``project.yml`` from the project config directory."""
    config_file = os.path.expandvars(PROJECT_CONFIG_PATH)
    return get_config(config_file)


def get_anim_prep_config():
    """Return the parsed ``anim_prep.yml`` from the project config directory."""
    config_file = os.path.expandvars(ANIM_PREP_CONFIG)
    return get_config(config_file)


class StartFromAnimPrepWidget(QtWidgets.QWidget):
    """Window for picking an (episode/)sequence/shot and loading its anim-prep USD.

    All selectors list directories from one shared ``QFileSystemModel``.  The
    directories they show come from the ``start_from_anim_prep`` section of
    ``anim_prep.yml`` (``episode_ui_selector``, ``sequence_ui_selector``,
    ``shot_ui_selector``) with built-in fallbacks.  On import, an empty
    section routes to :func:`start_from_anim_prep_scene`; otherwise the
    section's ``shot_path_template`` builds the path for
    :func:`start_from_anim_prep_shot`.
    """

    def __init__(self, parent=None):
        super(StartFromAnimPrepWidget, self).__init__(parent=parent)
        self.setWindowFlags(QtCore.Qt.Window)
        self.setWindowTitle('Start from Anim Prep')
        self.setLayout(QtWidgets.QVBoxLayout())

        # The YAML loader may return None (empty file / null section); treat
        # that as an empty config instead of failing on .get().
        self.project_config = get_project_config() or {}
        self.is_serial = self.project_config.get("is_series", False)
        if self.project_config.get("use_shots_name_convention", False):
            self.name_convention = SHOTS_NAME_CONVENTION
        else:
            self.name_convention = SCENES_NAME_CONVENTION
        self.anim_prep_config = (
            (get_anim_prep_config() or {}).get("start_from_anim_prep") or {})

        self.fs_model = QtWidgets.QFileSystemModel()
        self.fs_model.setRootPath('{}/'.format(os.environ.get("W")))

        # Created before the selectors so chech_shot() is safe to call as soon
        # as any selection signal fires; it is placed in the layout below.
        self.import_btn = QtWidgets.QPushButton("Import")
        self.import_btn.released.connect(self.on_import)

        hlt = QtWidgets.QHBoxLayout()

        if self.is_serial:
            hlt.addWidget(self._level_label("episodes"))
            self.episode_select = self._make_selector()
            episode_list_path = os.path.expandvars(
                self.anim_prep_config.get("episode_ui_selector", "").format(root_path=ROOT_PATH))
            self._set_selector_root(self.episode_select, episode_list_path)
            self.episode_select.currentIndexChanged.connect(
                self.update_sequence_list)
            hlt.addWidget(self.episode_select)
            hlt.addStretch()

        hlt.addWidget(self._level_label("sequences"))
        self.sequences_select = self._make_selector()
        if not self.is_serial:
            sequences_list_path = os.path.expandvars(self.anim_prep_config.get(
                "sequence_ui_selector", "{root_path}/episodes/").format(root_path=ROOT_PATH))
            self._set_selector_root(self.sequences_select, sequences_list_path)
        self.sequences_select.currentIndexChanged.connect(
            self.update_shot_list)
        hlt.addWidget(self.sequences_select)
        hlt.addStretch()

        hlt.addWidget(self._level_label("shots"))
        self.shots_select = self._make_selector()
        self.shots_select.currentIndexChanged.connect(
            self.chech_shot)
        hlt.addWidget(self.shots_select)

        self.layout().addLayout(hlt)

        button_box = QtWidgets.QHBoxLayout()
        button_box.addStretch()
        button_box.addWidget(self.import_btn)
        button_box.addStretch()
        self.layout().addLayout(button_box)

        if self.is_serial:
            self.episode_select.setCurrentIndex(0)
        else:
            self.sequences_select.setCurrentIndex(0)
        # Sync the button with the (possibly still empty) shot list; otherwise
        # it stays enabled until the first shot selection change.
        self.chech_shot()

    def _level_label(self, level):
        """Return a ``"<Name>:"`` label for *level* in the active naming convention."""
        return QtWidgets.QLabel(
            self.name_convention.get(level).capitalize() + ":")

    def _make_selector(self):
        """Return a combo box backed by the shared file-system model."""
        selector = QtWidgets.QComboBox()
        selector.setModel(self.fs_model)
        selector.setMinimumWidth(100)
        return selector

    def _set_selector_root(self, selector, path):
        """Make *selector* list the entries of directory *path* and request them."""
        root_index = self.fs_model.index(path)
        selector.setRootModelIndex(root_index)
        # QFileSystemModel populates directories lazily; ask for this one now.
        self.fs_model.fetchMore(root_index)

    def update_sequence_list(self):
        """Point the sequence selector at the directory of the selected episode."""
        episode = self.episode_select.currentText()
        template = self.anim_prep_config.get(
            "sequence_ui_selector", "{root_path}/episodes/{episode}/scenes/")
        # The built-in default uses {episode} while the value was only passed
        # as "episodes", which raised KeyError; provide both placeholder names.
        sequence_path = os.path.expandvars(template.format(
            root_path=ROOT_PATH, episodes=episode, episode=episode))
        self._set_selector_root(self.sequences_select, sequence_path)
        self.sequences_select.setCurrentIndex(0)

    def update_shot_list(self):
        """Point the shot selector at the directory of the selected sequence."""
        shot_path_vars = dict(root_path=ROOT_PATH, sequences=self.sequences_select.currentText())

        if self.is_serial:
            shot_path_vars["episodes"] = self.episode_select.currentText()

        template = self.anim_prep_config.get(
            "shot_ui_selector", "{root_path}/episodes/{sequences}/scenes/")
        shot_path = os.path.expandvars(template.format(**shot_path_vars))
        self._set_selector_root(self.shots_select, shot_path)
        self.shots_select.setCurrentIndex(0)

    def chech_shot(self):
        """Enable the Import button only while a shot is selected."""
        self.import_btn.setEnabled(bool(self.shots_select.currentText()))

    def on_import(self):
        """Start a fresh scene from the anim-prep USD of the current selection."""
        sequence = self.sequences_select.currentText()
        shot = self.shots_select.currentText()

        if not self.anim_prep_config:
            # No configuration: the two visible levels are episode and scene.
            start_from_anim_prep_scene(sequence, shot)
            return

        context = dict(sequences=sequence, shots=shot)
        if self.is_serial:
            episode = self.episode_select.currentText()
            context["episodes"] = episode
            name = "ep{episode}_sq{sequence}_sh{shot}".format(
                episode=episode, sequence=sequence, shot=shot)
        else:
            name = "sq{sequence}_sh{shot}".format(sequence=sequence, shot=shot)

        anim_prep_path = self.anim_prep_config.get(
            "shot_path_template").format(**context)

        start_from_anim_prep_shot(anim_prep_path, name)


def anim_prep_ui():
    """Open a new "Start from Anim Prep" window parented to the Maya main window.

    :return: the created :class:`StartFromAnimPrepWidget`.
    """
    maya_window = wizart.maya_qtutils.getMayaWindow(QtWidgets.QMainWindow)
    widget = StartFromAnimPrepWidget(maya_window)
    # Every call builds a new window; without this flag Qt only hides a
    # closed window, so closed instances would pile up under the Maya window.
    widget.setAttribute(QtCore.Qt.WA_DeleteOnClose)
    widget.show()
    return widget
