import { NumberTuple } from "../types";
import gpu from "./gpu";

/**
 * Builds a kernel that fills an output of the given `shape` with random
 * 0.0 / 1.0 values, drawing `Math.random()` once per output cell.
 *
 * A cell is 0.0 when the draw exceeds 0.5 and 1.0 otherwise.
 *
 * @param shape - output dimensions, forwarded unchanged to `setOutput`
 * @returns the kernel created on the shared `gpu` instance with its output set
 */
export const makeRandomGen = (shape: NumberTuple) => {
  const randomKernel = gpu.createKernel(function () {
    // 0.0 -> "main diag" type, 1.0 -> "off diag" type.
    return Math.random() > 0.5 ? 0.0 : 1.0;
  });
  return randomKernel.setOutput(shape);
};

/**
 * Builds a kernel that combines an image with a random share cell by cell.
 *
 * Each input value is thresholded at 0.5. The output cell is 1.0 when the
 * image value and the random value fall on the same side of that threshold
 * (both < 0.5, or both >= 0.5), and 0.0 otherwise.
 *
 * @param shape - output dimensions, forwarded unchanged to `setOutput`
 * @returns the kernel created on the shared `gpu` instance with its output set;
 *   it is invoked as `kernel(img, rand)` with 2-D arrays indexed `[y][x]`
 */
export const makeImageEncoder = (shape: NumberTuple) => {
  const encodeKernel = gpu.createKernel(function (
    img: number[][],
    rand: number[][]
  ) {
    const pixel = img[this.thread.y][this.thread.x];
    const noise = rand[this.thread.y][this.thread.x];

    // img pixel is "black": emit the opposite side of the random value.
    if (pixel < 0.5) {
      return noise < 0.5 ? 1.0 : 0.0;
    }

    // img pixel is "grey": emit the same side as the random value.
    return noise < 0.5 ? 0.0 : 1.0;
  });
  return encodeKernel.setOutput(shape);
};

/**
 * Builds a kernel that expands every input cell into a 2x2 block of output
 * cells, doubling both output dimensions.
 *
 * The source value is rounded to `bit`, and `bitInv` is 1.0 when
 * `bit < 0.5`, else 0.0. Within each 2x2 block, the top-left and
 * bottom-right cells receive `bit`; the top-right and bottom-left cells
 * receive `bitInv`.
 *
 * @param shape - dimensions of the input; the output is
 *   `[2 * shape[0], 2 * shape[1]]` (thread.x runs over the first entry)
 * @returns the kernel created on the shared `gpu` instance, with its output
 *   set and the "precision" tactic applied
 */
export const makeShapeEncoder = (shape: NumberTuple) => {
  const expandKernel = gpu.createKernel(function (img: number[][]) {
    // Every source cell owns the 2x2 output block at (2*srcX, 2*srcY).
    const srcX = Math.floor(this.thread.x / 2);
    const srcY = Math.floor(this.thread.y / 2);

    const bit = Math.round(img[srcY][srcX]);
    let bitInv = 0.0;
    if (bit < 0.5) bitInv = 1.0;

    // Cells on the block's main diagonal (top-left, bottom-right) have equal
    // column and row parity; the off-diagonal cells (top-right, bottom-left)
    // do not. A single parity comparison replaces the four-way branch.
    if (this.thread.x % 2 === this.thread.y % 2) {
      return bit;
    }
    return bitInv;
  });
  return expandKernel
    .setOutput([2 * shape[0], 2 * shape[1]])
    .setTactic("precision");
};
