import { mat4, quat, vec2, vec3 } from "gl-matrix"
import { Component } from "../../Component"
import { GameObject } from "../../GameObject"
import { GLTFMesh, GLTFRawAnimation } from "../GLTFLoader"
import { assert } from "../utils"
import { GLTFSkins } from "./Skins"

/**
 * A raw glTF animation regrouped by joint so it can be sampled per frame.
 *
 * For joint `j` (indexed as in `skins.jointNodes`), `channels[3 * j]` holds the
 * rotation keyframes (4 floats each), `channels[3 * j + 1]` the translation
 * keyframes (3 floats each) and `channels[3 * j + 2]` the scale keyframes
 * (3 floats each), with one view per entry of `inputs`.
 */
interface Animation {
    /** Optional name carried over from the raw animation. */
    name?: string
    /** Playback length carried over from the raw animation. */
    duration: number
    /** Keyframe timestamps shared by every channel. */
    inputs: Float32Array
    /** Per-joint keyframe views; see the interface comment for the layout. */
    channels: Float32Array[][]
}

/**
 * One playback instance of an `Animation`: its own time cursor, speed and
 * output pose buffer, plus optional root-motion state.
 */
interface Clip {
    /** The animation being played. */
    animation: Animation
    /** Pose written when the clip is sampled: [rotation, translation, scale] per joint. */
    out: Float32Array[]

    /** Current playback time, advanced by `delta * speed` on each update. */
    t: number
    /** Multiplier applied to the update delta. */
    speed: number

    /** Present only when the owning animator extracts root motion. */
    rootMotion?: RootMotion
}

/**
 * Root-motion bookkeeping for a clip whose animator has a root joint.
 */
interface RootMotion {
    /** Index of the root joint in `skins.jointNodes`. */
    jointIndex: number

    /** Root position at each keyframe, before the per-axis multiplier is applied. */
    motions: vec3[]
    /** Last keyframe position minus first keyframe position (one full loop), unscaled. */
    totalMotion: vec3

    /** Root position at the most recent sample (axis-scaled). */
    position: vec3
    /** Root position at the sample before that; rebased when the clip loops. */
    positionPrev: vec3
    /** `position - positionPrev` for the most recent sample. */
    motion: vec3
}

/**
 * Anything a blend graph can consume: a pose laid out as three entries per
 * joint ([rotation, translation, scale]) and, optionally, root-motion data.
 * Both `Clip` and `AnimationBlender` fit this shape structurally.
 */
export type AnimationNodeInput = {
    /** Pose buffers, three per joint. */
    out: Float32Array[]
    /** Current root position and the displacement since the previous sample. */
    rootMotion?: {
        motion: vec3
        position: vec3
    },
}

/**
 * Blend-graph node that mixes the poses (and, optionally, the root motion) of
 * two inputs. With `factor` 0 the output equals `inputs[0]`; with 1 it equals
 * `inputs[1]`.
 */
export class AnimationBlender {
    /** The two nodes being mixed; `inputs[0].out` determines the pose length. */
    inputs: [AnimationNodeInput, AnimationNodeInput]
    /** Blended pose ([rotation, translation, scale] per joint), written by `blend`. */
    out: Float32Array[]
    /** Weight of `inputs[1]`; starts at 0.5 and is not clamped. */
    factor: number

    /** Blended root motion; allocated only when constructed with `rootMotion = true`. */
    rootMotion?: {
        position: vec3,
        motion: vec3,
    }

    /**
     * @param inputs the two nodes to mix
     * @param rootMotion when true, also blend the inputs' `rootMotion`
     *   (both inputs must then provide it)
     */
    constructor(inputs: [AnimationNodeInput, AnimationNodeInput], rootMotion = false) {
        this.inputs = inputs
        this.out = AnimationBlender.allocatePose(inputs[0].out.length)
        this.factor = 0.5

        if (!rootMotion) {
            return
        }
        this.rootMotion = {
            position: vec3.create(),
            motion: vec3.create(),
        }
    }

    /** Allocates `length` slots cycling through quat (4), vec3 (3), vec3 (3) buffers. */
    private static allocatePose(length: number): Float32Array[] {
        const pose: Float32Array[] = Array(length)
        for (let slot = 0; slot < length; slot += 3) {
            pose[slot] = new Float32Array(4)
            pose[slot + 1] = new Float32Array(3)
            pose[slot + 2] = new Float32Array(3)
        }
        return pose
    }

    /**
     * Recomputes `out` (and `rootMotion`, if enabled) from the inputs' current
     * values: rotations are slerped, translations and scales are lerped.
     */
    blend() {
        const [first, second] = this.inputs
        const { out, factor } = this

        for (let base = 0; base < out.length; base += 3) {
            const rot = base
            const pos = base + 1
            const scl = base + 2
            quat.slerp(out[rot], first.out[rot], second.out[rot], factor)
            vec3.lerp(out[pos], first.out[pos], second.out[pos], factor)
            vec3.lerp(out[scl], first.out[scl], second.out[scl], factor)
        }

        const target = this.rootMotion
        if (target === undefined) {
            return
        }
        const fromFirst = first.rootMotion!
        const fromSecond = second.rootMotion!
        vec3.lerp(target.position, fromFirst.position, fromSecond.position, factor)
        vec3.lerp(target.motion, fromFirst.motion, fromSecond.motion, factor)
    }
}

/**
 * Samples skeletal animations for a skinned glTF mesh and applies the result
 * to its joints, optionally extracting root motion from a parentless root joint.
 *
 * Poses are flat `Float32Array[]` arrays with three entries per joint —
 * rotation (quat), translation (vec3), scale (vec3) — in `skins.jointNodes` order.
 */
export class Animator extends Component {
    /** Mesh the animations were loaded for. */
    gltf: GLTFMesh
    /** Skin driven by this animator (`gltf.skins`). */
    skins: GLTFSkins
    /** Animations regrouped per joint, in constructor order. */
    animations: Animation[]

    /** Playback instances created by `addClip`. */
    clips: Clip[]

    /** Indices into `clips` that `update` advances every frame. */
    currentClips: number[] = []

    /** Name of the parentless joint whose translation is extracted as root motion. */
    rootJointName?: string
    /** Root-motion displacement from the latest `update`; allocated by `addClip` when `rootJointName` is set. */
    rootMotionDisplacement?: vec3
    /** Per-axis multiplier applied to root motion; treated as (1, 1, 1) when omitted. */
    rootMotionAxes?: vec3

    /** Fallback for `rootMotionAxes`; never mutated. */
    private static readonly unitAxes: vec3 = vec3.fromValues(1, 1, 1)

    /**
     * @param parent game object this component belongs to
     * @param gltf skinned mesh whose `skins.jointNodes` the animations target
     * @param animations raw animations; each must animate rotation, translation
     *   and scale of every joint exactly once (checked with `assert`)
     * @param rootJointName optional parentless joint to extract root motion from
     * @param rootMotionAxes optional per-axis root-motion multiplier
     */
    constructor(parent: GameObject, gltf: GLTFMesh, animations: GLTFRawAnimation[], rootJointName?: string, rootMotionAxes?: vec3) {
        super(parent)

        this.gltf = gltf
        this.skins = gltf.skins

        this.rootJointName = rootJointName
        this.rootMotionAxes = rootMotionAxes

        const jointNodes = this.skins.jointNodes

        this.animations = animations.map(rawAnimation => {
            const numKeyframes = rawAnimation.inputs.length

            // One bit per path (rotation / translation / scale) not yet seen for each joint.
            // XOR means a duplicated channel flips its bit back and fails the final assert.
            const check = new Uint8Array(jointNodes.length).fill(0b111)
            const channels: Animation['channels'] = Array(jointNodes.length * 3)

            // Splits a flat keyframe buffer into one view (no copy) per keyframe.
            const splitKeyframes = (output: Float32Array, stride: number) =>
                Array.from({ length: numKeyframes }, (_, i) => output.subarray(i * stride, (i + 1) * stride))

            rawAnimation.channels.forEach(channel => {
                const nodeIdx = jointNodes.findIndex(n => n.name === channel.target.name)
                assert(nodeIdx !== -1)

                const path = channel.target.path
                if (path === 'rotation') {
                    check[nodeIdx] ^= 0b001
                    channels[nodeIdx * 3 + 0] = splitKeyframes(channel.output, 4)
                }
                else if (path === 'translation') {
                    check[nodeIdx] ^= 0b010
                    channels[nodeIdx * 3 + 1] = splitKeyframes(channel.output, 3)
                }
                else if (path === 'scale') {
                    check[nodeIdx] ^= 0b100
                    channels[nodeIdx * 3 + 2] = splitKeyframes(channel.output, 3)
                }
            })

            // asserts that the animation modifies every jointNode's rotation, translation and scale
            assert(check.reduce((acc, x) => acc | x, 0) === 0)

            return {
                name: rawAnimation.name,
                duration: rawAnimation.duration,
                inputs: rawAnimation.inputs,
                channels: channels,
            }
        })

        this.clips = []
    }

    /**
     * Creates a playback clip for `this.animations[animationIdx]` and appends it to `clips`.
     *
     * When `rootJointName` is set, also precomputes the root joint's per-keyframe
     * position for root-motion extraction and (re)allocates `rootMotionDisplacement`.
     *
     * @returns the new length of `clips` (the `Array.prototype.push` result), so the
     *   new clip's index is the returned value minus one.
     *   NOTE(review): passing this value straight to `play` is off by one; kept as-is
     *   for compatibility with existing callers.
     */
    addClip(animationIdx: number) {
        const animation = this.animations[animationIdx]

        // Three pose slots per joint: rotation quat, translation vec3, scale vec3.
        const out: Float32Array[] = Array(this.skins.jointNodes.length * 3)
        for (let i = 0; i < out.length; i += 3) {
            out[i + 0] = new Float32Array(4)
            out[i + 1] = new Float32Array(3)
            out[i + 2] = new Float32Array(3)
        }

        const clip: Clip = {
            animation,
            t: 0,
            speed: 1,
            out,
        }

        if (this.rootJointName !== undefined) {
            const rootIndex = this.skins.jointNodes.findIndex(node => node.name === this.rootJointName)

            assert(rootIndex !== -1)
            assert(this.skins.jointParents[rootIndex] === -1)

            this.rootMotionDisplacement = vec3.create()

            // Transforming the origin by the root's TRS matrix yields only its
            // translation, so the per-keyframe root position is a copy of the
            // translation keyframes.
            const motions = animation.channels[rootIndex * 3 + 1].map(translation => vec3.clone(translation))

            clip.rootMotion = {
                jointIndex: rootIndex,
                motion: vec3.create(),
                motions,
                totalMotion: vec3.sub(vec3.create(), motions[motions.length - 1], motions[0]),
                position: vec3.create(),
                positionPrev: vec3.create(),
            }
        }

        return this.clips.push(clip)
    }

    /**
     * Makes `indices` the set of clips advanced by `update`, rewinds each of them to
     * t = 0 (resetting root-motion state) and applies its first pose to the joints.
     * If several clips are given, the last one's pose is what remains applied.
     */
    play(indices: number[]) {
        // Copy so later mutation of the caller's array does not change what plays.
        this.currentClips = [...indices]

        for (const index of indices) {
            const clip = this.clips[index]

            clip.t = 0

            if (clip.rootMotion !== undefined) {
                vec3.set(clip.rootMotion.position, 0, 0, 0)
                vec3.set(clip.rootMotion.positionPrev, 0, 0, 0)
                vec3.set(clip.rootMotion.motion, 0, 0, 0)
            }

            // Sample first; updateJoints derives the forward matrices from the sampled pose.
            this.calcAnimation(index, clip.out)
            this.updateJoints(clip.out)
        }
    }

    /**
     * Evaluates the blend graph rooted at `node` depth-first, blending each
     * `AnimationBlender` once, after its inputs. Non-blender nodes are leaves
     * whose `out` is already current (clips are sampled by `update`).
     */
    private updateNode(node: AnimationNodeInput, visited: Set<AnimationNodeInput>) {
        if (!(node instanceof AnimationBlender) || visited.has(node)) {
            return
        }
        // Mark on entry so shared sub-graphs are blended once and a cycle cannot recurse forever.
        visited.add(node)
        for (const input of node.inputs) {
            this.updateNode(input, visited)
        }
        node.blend()
    }

    /**
     * Advances every playing clip by `delta`, evaluates the blend graph rooted at `d`
     * and applies `d.out` to the skin's joints.
     *
     * When `d` carries root motion, `skins.updateJoints` receives a translation by
     * `-d.rootMotion.position` (presumably to keep the mesh in place — confirm against
     * GLTFSkins) and `rootMotionDisplacement` receives `d.rootMotion.motion`.
     *
     * @param d output node of the blend graph (an `AnimationBlender` or a clip-like leaf)
     * @param delta elapsed time, in the units of `Animation.duration`
     */
    update(d: AnimationNodeInput, delta: number) {
        for (const index of this.currentClips) {
            const clip = this.clips[index]
            this.advanceClip(clip, delta)
            this.calcAnimation(index, clip.out)
        }

        this.updateNode(d, new Set())
        this.skins.updateJointForwardMatrices(d.out)

        const rootMotion = d.rootMotion
        if (rootMotion === undefined) {
            this.skins.updateJoints()
            return
        }

        this.skins.updateJoints(mat4.fromTranslation(mat4.create(), vec3.negate(vec3.create(), rootMotion.position)))
        if (this.rootMotionDisplacement !== undefined) {
            vec3.copy(this.rootMotionDisplacement, rootMotion.motion)
        }
    }

    /**
     * Advances `clip.t` by `delta * clip.speed` and wraps it back below the
     * animation's duration once it reaches or passes the end.
     *
     * For clips with root motion, `positionPrev` is shifted back by one full
     * (axis-scaled) loop of motion per wrap, so the next sampled displacement stays
     * continuous across the loop point.
     */
    private advanceClip(clip: Clip, delta: number) {
        const duration = clip.animation.duration
        clip.t += delta * clip.speed

        // `<` rather than `<=`: landing exactly on `duration` must wrap as well. Otherwise
        // calcAnimation samples keyframe 0 without rebasing positionPrev and the root
        // jumps backwards for a frame.
        if (duration <= 0 || clip.t < duration) {
            return
        }

        const wrapped = clip.t % duration
        // Derive the wrap count from the remainder so the two always agree, even when
        // clip.t / duration rounds up to an integer.
        const wraps = Math.round((clip.t - wrapped) / duration)
        clip.t = wrapped

        const rootMotion = clip.rootMotion
        if (rootMotion !== undefined) {
            const axes = this.rootMotionScale()
            for (let k = 0; k < 3; k++) {
                rootMotion.positionPrev[k] -= rootMotion.totalMotion[k] * wraps * axes[k]
            }
        }
    }

    /** Per-axis root-motion multiplier; (1, 1, 1) when `rootMotionAxes` is unset. */
    private rootMotionScale(): vec3 {
        return this.rootMotionAxes ?? Animator.unitAxes
    }

    /**
     * Samples clip `index` at its current time `clip.t`.
     *
     * Updates the clip's root motion (if any) and, when `out` is given, writes the
     * interpolated pose into it. Times before the first keyframe sample the first
     * keyframe. Times past the last keyframe sample the last one (no extrapolation).
     * Single-keyframe or non-positive-duration animations always sample keyframe 0.
     */
    private calcAnimation(index: number, out?: Float32Array[]) {
        const clip = this.clips[index]
        const animation = clip.animation
        const inputs = animation.inputs
        const numKeyframes = inputs.length

        let inputIdx0 = 0
        let inputIdx1 = 0
        let s1 = 0
        if (numKeyframes > 1 && animation.duration > 0) {
            const t = inputs[0] + (clip.t % animation.duration)
            // First keyframe strictly after t, or -1 if t is at/after the last one.
            const next = inputs.findIndex(v => v > t)
            inputIdx0 = next === -1 ? numKeyframes - 2 : Math.max(next - 1, 0)
            inputIdx1 = inputIdx0 + 1

            const span = inputs[inputIdx1] - inputs[inputIdx0]
            // Clamp to [0, 1] so rotations are never extrapolated by slerp;
            // a zero-length span (duplicate timestamps) samples inputIdx0.
            s1 = span > 0 ? Math.min(Math.max((t - inputs[inputIdx0]) / span, 0), 1) : 0
        }

        const rootMotion = clip.rootMotion
        if (rootMotion !== undefined) {
            vec3.lerp(
                rootMotion.position,
                rootMotion.motions[inputIdx0],
                rootMotion.motions[inputIdx1],
                s1
            )
            // Apply the per-axis root-motion multiplier.
            vec3.mul(rootMotion.position, rootMotion.position, this.rootMotionScale())
            // Displacement since the previous sample.
            vec3.sub(rootMotion.motion, rootMotion.position, rootMotion.positionPrev)
            vec3.copy(rootMotion.positionPrev, rootMotion.position)
        }

        if (out === undefined) {
            return
        }

        const channels = animation.channels
        for (let i = 0; i < channels.length; i += 3) {
            const rotations = channels[i + 0]
            const translations = channels[i + 1]
            const scales = channels[i + 2]
            quat.slerp(out[i + 0], rotations[inputIdx0], rotations[inputIdx1], s1)
            vec3.lerp(out[i + 1], translations[inputIdx0], translations[inputIdx1], s1)
            vec3.lerp(out[i + 2], scales[inputIdx0], scales[inputIdx1], s1)
        }
    }

    /** Recomputes the skin's forward matrices from `animation` and updates its joints. */
    private updateJoints(animation: Float32Array[]) {
        this.skins.updateJointForwardMatrices(animation)
        this.skins.updateJoints()
    }
}
