import React from "react"
import { gsap } from "gsap"
import * as THREE from "three"
//@ts-ignore
import vertex from "./glsl/vertex.glsl"
//@ts-ignore
import fragment from "./glsl/fragment.glsl"
import { UniformsType } from "@canvas-components/canvas-types"
import { useThree } from "@react-three/fiber"

/**
 * Renders `materialType` via `React.createElement` and patches its shaders in
 * `onBeforeCompile`:
 * - registers a `uTime` float uniform plus one uniform per entry of `uniforms`;
 * - prepends declarations/helpers (`addTopShaders`) and splices in the
 *   project's GLSL snippets (`addBottomShaders`);
 * - stores the patched shader on the material's `userData.shader` and calls
 *   `setCompiled` (if provided).
 *
 * Props:
 * - `materialType` – element type handed to `React.createElement`
 *   (presumably an R3F material intrinsic — TODO confirm).
 * - `uniforms` – list of `{ id, type, value }` descriptors. Values are
 *   registered for the types "colour", "float", "vec3" and "sampler2D".
 * - `setCompiled` – optional callback invoked once the shader is patched.
 * - `mesh` – destructured so it is not forwarded to the material; unused here.
 * - every other prop is forwarded to the material element.
 *
 * The forwarded ref exposes `update(action, v, callback)`. It tweens the
 * uniform named `action` to `v` over 0.2s with gsap (linear ease) and calls
 * `invalidate()` on every tick. `callback` (optional) receives "start" and
 * "complete".
 */
const ExtendedMaterial: React.FC<any> = React.forwardRef(
  ({ materialType, uniforms, setCompiled, mesh, ...props }, ref) => {
    const shaderRef = React.useRef<THREE.Material>()

    const { invalidate } = useThree()

    const onBeforeCompile = (shader) => {
      // Guard against patching the same shader object twice.
      if (!shader.__modified) {
        // A missing `uniforms` prop used to be logged and then crash below;
        // fall back to an empty list so the shader still compiles.
        if (!uniforms) console.error("uniforms not defined")
        const uniformList = uniforms ?? []

        shader.uniforms.uTime = {
          value: 0.0, // Initialize with a default value
        }

        for (const uniform of uniformList) {
          switch (uniform.type) {
            case "colour":
              shader.uniforms[uniform.id] = {
                value: new THREE.Color(uniform.value as string),
              }
              break
            case "vec3":
              shader.uniforms[uniform.id] = {
                value: new THREE.Vector3(
                  uniform.value[0],
                  uniform.value[1],
                  uniform.value[2]
                ),
              }
              break
            case "float":
            case "sampler2D":
              // Scalars and textures are passed through unchanged.
              shader.uniforms[uniform.id] = {
                value: uniform.value,
              }
              break
            // Any other type is still declared in GLSL by addTopShaders but
            // gets no value registered here.
          }
        }

        addTopShaders(shader, uniformList)
        addBottomShaders(shader)
        if (shaderRef.current) shaderRef.current.userData = { shader }
        setCompiled?.()
      }

      shader.__modified = true
    }

    const update = (action, v, callback) => {
      const target = shaderRef.current?.userData?.shader?.uniforms?.[action]
      if (!target) {
        // Either the shader has not been compiled yet or `action` is not a
        // registered uniform. Previously this threw a TypeError or handed gsap
        // an undefined target.
        console.warn(
          `ExtendedMaterial.update: uniform "${action}" is not available yet`
        )
        return
      }
      gsap.to(target, {
        onStart: () => {
          callback?.("start")
        },
        value: v,
        ease: "none",
        duration: 0.2,
        onUpdate: () => {
          // Request a frame for each tween tick (e.g. under an on-demand frameloop).
          invalidate()
        },
        onComplete: () => {
          callback?.("complete")
        },
      })
    }

    React.useImperativeHandle(ref, () => ({
      update,
    }))

    return React.createElement(materialType, {
      ref: shaderRef,
      onBeforeCompile: onBeforeCompile,
      ...props,
    })
  }
)

/**
 * Prepends declarations to both shader sources, mutating `shader` in place.
 *
 * - Vertex shader: gains the `vUv` and `vNormal_1` varyings.
 * - Fragment shader: gains the same varyings and one `uniform` line per entry
 *   of `uniforms` ("colour" becomes `vec3`; any other type name is used
 *   verbatim). It also gains a `uTime` float uniform and a `toLinear` helper
 *   that converts an sRGB colour to linear space.
 *
 * @param shader - shader whose `vertexShader` / `fragmentShader` are rewritten.
 * @param uniforms - uniform descriptors to declare in the fragment shader.
 */
const addTopShaders = (
  shader: { vertexShader: string; fragmentShader: string },
  uniforms: UniformsType[]
) => {
  const declarationLines: string[] = []
  for (const { id, type } of uniforms) {
    // "colour" is this project's name for an RGB colour, stored as vec3 in GLSL.
    const glslType = type === "colour" ? "vec3" : type
    declarationLines.push(`uniform ${glslType} ${id};`)
  }
  const uniformDeclarations = declarationLines.join("\n")

  const originalVertex = shader.vertexShader
  const originalFragment = shader.fragmentShader

  shader.vertexShader = `
    varying vec2 vUv;
    varying vec3 vNormal_1;
    ${originalVertex}
  `

  shader.fragmentShader = `
    varying vec2 vUv;
    varying vec3 vNormal_1;
    ${uniformDeclarations}
    uniform float uTime;
    vec4 toLinear(vec4 sRGB)
    {
      bvec3 cutoff = lessThan(sRGB.rgb, vec3(0.04045));
      vec3 higher = pow((sRGB.rgb + vec3(0.055))/vec3(1.055), vec3(2.4));
      vec3 lower = sRGB.rgb/vec3(12.92);
      return vec4(mix(higher, lower, cutoff), sRGB.a);
    }

    ${originalFragment}
  `
}

/**
 * Splices the project's GLSL snippets (`vertex` / `fragment`) into three.js'
 * built-in shader chunks, mutating `shader` in place.
 *
 * - The vertex snippet replaces `#include <uv_vertex>`.
 * - The fragment snippet replaces the material's final-output chunk. three.js
 *   r154 renamed that chunk from `output_fragment` to `opaque_fragment`, so
 *   whichever one is present is replaced. If neither is found, the fragment
 *   shader is left untouched and a warning is logged instead of failing
 *   silently.
 *
 * Replacement is done with a function so that `$` sequences in the GLSL are
 * not treated as `String.prototype.replace` substitution patterns.
 */
const addBottomShaders = (shader: {
  vertexShader: string
  fragmentShader: string
}) => {
  shader.vertexShader = shader.vertexShader.replace(
    "#include <uv_vertex>",
    () => vertex
  )

  // "#include <dithering_fragment>" was previously considered as the anchor.
  const outputChunk = [
    "#include <output_fragment>",
    "#include <opaque_fragment>",
  ].find((chunk) => shader.fragmentShader.includes(chunk))

  if (outputChunk === undefined) {
    console.warn(
      "addBottomShaders: no output_fragment/opaque_fragment chunk found; fragment snippet not injected"
    )
    return
  }

  shader.fragmentShader = shader.fragmentShader.replace(
    outputChunk,
    () => fragment
  )
}

export default ExtendedMaterial
