import { mat4, quat, vec3 } from "gl-matrix"
import { Component } from "../../Component"
import { GameObject } from "../../GameObject"
import { FREE_INDEX0, Graphics } from "../../Graphics"
import { CameraComponent } from "../Camera"
import { RenderableObject } from "../RenderableObject"
import { ShadowMap } from "../ShadowMap"
import { addSetUniform, setGbuffer } from "../utils"
import { getDirectionalLightFS } from "./DirectionalLightFS"

/** Construction options for a {@link DirectionalLightComponent}. */
export type DirectionalLightOptions = {
    /** Light colour, passed to the shader's `uLightColor` uniform. */
    color: vec3,
    /** Scalar brightness, passed to the shader's `uLightIntensity` uniform. */
    intensity: number,
    /**
     * How many shadow-map samplers the fragment shader is compiled with.
     * Leaving it out (or passing `0`) builds the shader without shadow mapping.
     */
    shadowMapsNum?: number,
}

/**
 * Directional light rendered as a single 4-vertex triangle-strip pass, with optional
 * cascaded shadow maps.
 *
 * Shadow usage: construct with `shadowMapsNum`, call {@link setShadowMaps} with one
 * `ShadowMap` per cascade, call {@link update} each frame to fit the cascades to the
 * camera, then {@link render}.
 */
export class DirectionalLightComponent extends Component {
    private gl: WebGL2RenderingContext

    private ro: RenderableObject

    /**
     * Extra depth added to the near (+Z, light-space) side of every cascade's orthographic
     * volume — presumably so shadow casters just outside the camera slice still render
     * into the map. Defaults to the previously hard-coded 30.
     */
    shadowCasterMargin = 30

    /** Number of shadow-map samplers the shader was compiled with (from the options). */
    shadowMapsNum?: number
    /** Contiguous storage for every cascade matrix; uploaded as `uShadowMapMats`. */
    shadowMapMatsBuffer?: Float32Array
    /** Per-cascade 16-float views into `shadowMapMatsBuffer` (projection * inverse(transform)). */
    shadowMapMats?: mat4[]
    /** Per-cascade light "camera" transforms: light rotation plus cascade-slice centre. */
    shadowMapTransforms?: mat4[]
    /** Per-cascade orthographic projections. */
    shadowMapProjections?: mat4[]
    /** Shadow maps, one per cascade, as passed to `setShadowMaps`. */
    shadowMaps?: ShadowMap[]
    /** Cached `texture` of each entry in `shadowMaps`. */
    shadowMapTexs?: WebGLTexture[]

    /**
     * Compiles the light shader and uploads the constant uniforms.
     *
     * @param parent game object whose transform rotation orients the light
     * @param options colour, intensity and optional shadow-map count
     */
    constructor(parent: GameObject, options: DirectionalLightOptions) {
        super(parent)

        const gl = Graphics.context
        this.gl = gl

        const defines: Record<string, unknown> = {}

        this.shadowMapsNum = options.shadowMapsNum
        if (options.shadowMapsNum) {
            defines['USE_SHADOW_MAP'] = ''
            defines['SHADOW_MAPS_NUM'] = options.shadowMapsNum
            // NOTE(review): `hasShadowMap` is not declared in this class — presumably
            // inherited from Component; verify its declared type.
            this.hasShadowMap = options.shadowMapsNum
        }

        const vShader = Graphics.shaders.getShader('basicVS')
        const fShader = Graphics.shaders.compileShaderFromStr(getDirectionalLightFS(options.shadowMapsNum), {type: 'fragment', defines})

        const ro = new RenderableObject(gl, vShader, fShader, gl.TRIANGLE_STRIP, 4)
        this.ro = ro
        ro.useProgram()
        ro.addUniform('uView', 'mat4')
        ro.addUniform('uOrthographic', 'bool')
        ro.addUniform('uLightDirection', 'vec3')
        addSetUniform(ro, 'uLightColor', 'vec3', options.color)
        addSetUniform(ro, 'uLightIntensity', 'float', options.intensity)

        if (options.shadowMapsNum) {
            ro.addUniform('uShadowMapMats', 'mat4')
            // Sampler i reads texture unit FREE_INDEX0 + i (bound in render()).
            addSetUniform(ro, 'uShadowMaps', 'int', new Int32Array(options.shadowMapsNum).map((_, i) => FREE_INDEX0 + i))
        }

        setGbuffer(ro)
    }

    /**
     * Sets (or clears, with `undefined`) the cascade shadow maps and allocates the per-cascade
     * matrices used by {@link setFromCamera} and {@link render}.
     */
    setShadowMaps(shadowMaps: ShadowMap[] | undefined) {
        this.shadowMaps = shadowMaps
        if (!shadowMaps) {
            this.shadowMapTexs = undefined
            this.shadowMapMatsBuffer = undefined
            this.shadowMapMats = undefined
            this.shadowMapProjections = undefined
            this.shadowMapTransforms = undefined
            return
        }

        const buffer = new Float32Array(shadowMaps.length * 16)
        this.shadowMapTexs = shadowMaps.map(sm => sm.texture)
        this.shadowMapMatsBuffer = buffer
        // Views share `buffer`, so the whole array can be uploaded with one uniform call.
        this.shadowMapMats = shadowMaps.map((_, i) => buffer.subarray(i * 16, i * 16 + 16))
        this.shadowMapProjections = shadowMaps.map(_ => mat4.create())
        this.shadowMapTransforms = shadowMaps.map(_ => mat4.create())
    }

    /**
     * Fits each cascade's light transform and orthographic projection to a depth slice of
     * the camera's view volume.
     *
     * Cascade `i` covers the camera slice between `planeZs[i]` and `planeZs[i + 1]`, which are
     * distances in front of the camera. Only the first
     * `min(shadowMaps.length, planeZs.length - 1)` cascades are updated. Does nothing if no
     * shadow maps are set or the camera type is neither perspective nor orthographic.
     */
    setFromCamera(camera: CameraComponent, planeZs: number[]) {
        const shadowMaps = this.shadowMaps
        const transforms = this.shadowMapTransforms
        const projections = this.shadowMapProjections
        if (!shadowMaps || !transforms || !projections) {
            return
        }

        // Half width/height of the camera's cross-section at distance z.
        const data = camera.data
        let halfExtentsAt: (z: number) => [number, number]
        if (data.type === 'perspective') {
            const ty = Math.tan(data.fovy / 2)
            const tx = ty * data.aspect
            halfExtentsAt = z => [tx * z, ty * z]
        }
        else if (data.type === 'orthographic') {
            const w = data.width / 2
            const h = data.width / data.aspect / 2
            halfExtentsAt = () => [w, h]
        }
        else {
            return
        }

        const rotation = this.parent.transform.rotation
        // Inverse of the light rotation: takes world-space points into the light's frame.
        const toLight = quat.conjugate(quat.create(), rotation)

        const planeBounds = planeZs.map(z => {
            const [halfW, halfH] = halfExtentsAt(z)
            return this.sliceBounds(camera, halfW, halfH, z, toLight)
        })

        const count = Math.min(shadowMaps.length, planeZs.length - 1)
        for (let i = 0; i < count; i++) {
            const [nearMin, nearMax] = planeBounds[i]
            const [farMin, farMax] = planeBounds[i + 1]

            // Light-space AABB of the whole slice (near plane ∪ far plane). The centre MUST
            // come from this merged box: the projection below is symmetric about it.
            const min = vec3.min(vec3.create(), nearMin, farMin)
            const max = vec3.max(vec3.create(), nearMax, farMax)

            const center = vec3.add(vec3.create(), min, max)
            vec3.scale(center, center, .5)
            vec3.transformQuat(center, center, rotation)
            mat4.fromRotationTranslation(transforms[i], rotation, center)

            const half = vec3.sub(vec3.create(), max, min)
            vec3.scale(half, half, .5)

            mat4.ortho(projections[i], -half[0], half[0], -half[1], half[1], -half[2] - this.shadowCasterMargin, half[2])
        }
    }

    /**
     * Light-frame axis-aligned bounds of the four corners of the camera cross-section at
     * distance `z` (view-space z = -z).
     *
     * @param camera camera whose transform matrix maps view space to world space
     * @param halfW half width of the cross-section at `z`
     * @param halfH half height of the cross-section at `z`
     * @param z distance in front of the camera
     * @param toLight rotation from world space into the light's frame
     * @returns `[min, max]` corners of the bounding box
     */
    private sliceBounds(camera: CameraComponent, halfW: number, halfH: number, z: number, toLight: quat): [vec3, vec3] {
        const cameraMat = camera.transform.matrix
        const min = vec3.fromValues(Infinity, Infinity, Infinity)
        const max = vec3.fromValues(-Infinity, -Infinity, -Infinity)
        const corner = vec3.create()
        for (const [sx, sy] of [[1, 1], [1, -1], [-1, 1], [-1, -1]] as const) {
            vec3.set(corner, sx * halfW, sy * halfH, -z)
            vec3.transformMat4(corner, corner, cameraMat)
            vec3.transformQuat(corner, corner, toLight)
            vec3.min(min, min, corner)
            vec3.max(max, max, corner)
        }
        return [min, max]
    }

    /**
     * Refits the shadow cascades to `camera`; no-op when no shadow maps are set.
     *
     * Split distances are `near`, then `near + 4, 8, 16, 32, …` — one more plane than there
     * are cascades. This matches the previous fixed list for up to 4 cascades and extends
     * it for more.
     */
    update(camera: CameraComponent) {
        const projections = this.shadowMapProjections
        if (!projections) {
            return
        }
        const near = camera.data.near
        const planeZs = [near]
        for (let i = 1; i <= projections.length; i++) {
            planeZs.push(near + 2 ** (i + 1))
        }
        this.setFromCamera(camera, planeZs)
    }

    /**
     * Draws the light pass: uploads per-frame uniforms, binds the cascade shadow maps (when
     * shadows are enabled and maps are set) and renders.
     *
     * @param viewMatrix uploaded as `uView`
     * @param projectionMatrix currently unused
     */
    render(viewMatrix?: mat4, projectionMatrix?: mat4) {
        const ro = this.ro
        ro.useProgram()
        ro.setUniform('uLightDirection', this.parent.transform.getForward())
        ro.setUniform('uView', viewMatrix)
        ro.setUniform('uOrthographic', Graphics.camera?.data.type === 'orthographic')

        const shadowMaps = this.shadowMaps
        const texs = this.shadowMapTexs
        const mats = this.shadowMapMats
        const buffer = this.shadowMapMatsBuffer
        const transforms = this.shadowMapTransforms
        const projections = this.shadowMapProjections
        if (this.hasShadowMap && shadowMaps && texs && mats && buffer && transforms && projections) {
            const gl = this.gl
            for (let i = 0; i < shadowMaps.length; i++) {
                gl.activeTexture(gl.TEXTURE0 + FREE_INDEX0 + i)
                gl.bindTexture(gl.TEXTURE_2D, texs[i])
                // Light view-projection: projection * inverse(light transform), written in place.
                const mat = mats[i]
                mat4.invert(mat, transforms[i])
                mat4.mul(mat, projections[i], mat)
            }
            ro.setUniform('uShadowMapMats', buffer)
        }

        ro.render()
    }
}
