import * as tf from "@tensorflow/tfjs";
import { Player } from "./player";
import { FIELD_SIZE } from "./nardy";
import { MediumAIPlayer_data } from "./MediumAIPlayer_data";
/*****************************************************************************/

/**
 * Picks the move sequence to play when the current player can bear chips off
 * (the callers only use this when `canDoOutStep` is true).
 *
 * Candidates are ranked by (chip counter, `sum_to_end`) for the current
 * player, lower being better. A candidate is chosen only if it strictly beats
 * both the position before moving (`game.field`) and every earlier candidate.
 * If none does, the first candidate is returned (`undefined` for an empty list).
 *
 * @param game - object exposing `current_player` and `field`.
 * @param possible_moves - candidates, each with the resulting `field`.
 * @returns the chosen element of `possible_moves`.
 */
function chooseThrowMoves(game, possible_moves) {
  const player = game.current_player;

  const rank = (field) => {
    const sum = field.getFieldDataForPlayer(player).sum_to_end;
    const chips = field.getChipsCounterForPlayer(player);
    return { chips, sum };
  };

  let chosen = possible_moves[0];
  let best = rank(game.field);

  for (const candidate of possible_moves) {
    const current = rank(candidate.field);
    const fewerChips = current.chips < best.chips;
    const sameChipsCloser =
      current.chips === best.chips && current.sum < best.sum;

    if (fewerChips || sameChipsCloser) {
      chosen = candidate;
      best = current;
    }
  }

  return chosen;
}

/**
 * Custom layer that splits its input into two equal halves along axis 1.
 *
 * It is registered with `tf.serialization` below so that `tf.loadLayersModel`
 * can rebuild saved models containing a layer of class "SplitL".
 */
class SplitL extends tf.layers.Layer {
  /**
   * @param {object} [config] - layer config. tfjs' default `fromConfig`
   *   constructs the layer with the saved config (name, trainable, ...), so
   *   forward it to the base class instead of discarding it.
   */
  constructor(config = {}) {
    super(config);
  }

  /**
   * Both outputs have the input's shape with axis 1 halved. An unknown
   * (null) axis-1 size stays unknown instead of being coerced to 0.
   */
  computeOutputShape(inputShape) {
    const width = inputShape[1];
    const half = width == null ? null : Math.round(width / 2);
    const outShape = inputShape.slice();
    outShape[1] = half;
    return [outShape, outShape.slice()];
  }

  call(inputs, kwargs) {
    // tfjs may hand `call` either a single Tensor or a one-element Tensor[].
    const input = Array.isArray(inputs) ? inputs[0] : inputs;
    return tf.split(input, 2, 1);
  }

  static get className() {
    return "SplitL";
  }
}
tf.serialization.registerClass(SplitL); // Needed for serialization.

// Width of one network input row: two 216-slot board encodings (the position
// before moving and the position after the candidate moves), as built by
// HardAIPlayer.extractDataRow.
const INPUT_SIZE = 432;

// The two models are independent downloads; fetch them in parallel instead of
// waiting for one before starting the other.
const [model_hard, model_uhard] = await Promise.all([
  tf.loadLayersModel("/model_web/model.json"),
  tf.loadLayersModel("/model_web_uhard/model.json"),
]);
model_hard.compile({ loss: "meanSquaredError", optimizer: "adam" });
model_uhard.compile({ loss: "meanSquaredError", optimizer: "adam" });
/*****************************************************************************/
export class HardAIPlayer {
  /**
   * Player that ranks candidate move sequences with a neural network.
   *
   * @param model - tfjs model. `predict` is called with a
   *   `[candidates, INPUT_SIZE]` tensor, and `arraySync()` of the result is
   *   read as one score per row in column 0 (highest score wins).
   */
  constructor(model) {
    this.model = model;
  }

  /**
   * Encodes `field` from `player`'s perspective into `inputData`.
   * Slots are relative to `index_shift`:
   *
   *   0-95     4 slots per raw entry i (4i..4i+3) for positive entries
   *            of getDataForPlayer (24 entries assumed)
   *   96       max_line of `player`
   *   97       (15 - chip counter of `player`) / 15
   *   98-193   4 slots per raw entry for negative entries
   *   194      max_line of the next player
   *   195      (15 - chip counter of the next player) / 15
   *   196-197  never written (keep the caller's initial value)
   *   198-206  normalized summary stats of `player`
   *   207-215  the same stats for the next player
   *
   * A stack of n chips becomes [1, n > 1, n > 2, n > 3 ? (n - 3) / 2 : 0].
   * NOTE(review): this layout presumably matches what the model was trained
   * on, so do not change it without retraining.
   */
  extractFieldDataRow(inputData, index_shift, player, field) {
    const STACK_BLOCK = 24 * 4; // 4 slots for each of the 24 points
    const SIDE_BLOCK = STACK_BLOCK + 2; // stacks + max_line + chip ratio

    // Writes the 4-slot encoding of a stack of `count` (> 0) chips at `index`.
    const encodeStack = (index, count) => {
      inputData[index] = 1;
      if (count > 1) inputData[index + 1] = 1;
      if (count > 2) inputData[index + 2] = 1;
      if (count > 3) inputData[index + 3] = (count - 3) / 2;
    };

    const field_raw_data = field.getDataForPlayer(player);
    for (let i = 0; i < field_raw_data.length; i++) {
      const value = field_raw_data[i];
      if (value > 0) {
        encodeStack(4 * i + index_shift, value);
      } else if (value < 0) {
        // Negative entries are stored after the positive block (+ 2 scalars).
        encodeStack(4 * i + SIDE_BLOCK + index_shift, -value);
      }
    }

    const enemy = Player.getNextPlayer(player);
    const field_data = field.getFieldDataForPlayer(player);
    const enemy_field_data = field.getFieldDataForPlayer(enemy);

    inputData[STACK_BLOCK + index_shift] = field_data.max_line;
    inputData[STACK_BLOCK + 1 + index_shift] =
      (15 - field.getChipsCounterForPlayer(player)) / 15;

    inputData[SIDE_BLOCK + STACK_BLOCK + index_shift] =
      enemy_field_data.max_line;
    inputData[SIDE_BLOCK + STACK_BLOCK + 1 + index_shift] =
      (15 - field.getChipsCounterForPlayer(enemy)) / 15;

    // Summary stats start at 198 (= 2 * SIDE_BLOCK + 2); 196-197 are skipped.
    let index = 2 * SIDE_BLOCK + 2 + index_shift;
    for (const data of [field_data, enemy_field_data]) {
      inputData[index++] = data.sum_to_end / 360;
      inputData[index++] = data.has_block;
      inputData[index++] = data.min_pos / 24;
      inputData[index++] = data.max_pos / 24;
      inputData[index++] = data.max_dist / 5;
      inputData[index++] = data.block_start_1 / 24;
      inputData[index++] = data.block_end_1 / 24;
      inputData[index++] = data.block_start_2 / 24;
      inputData[index++] = data.block_end_2 / 24;
    }
  }

  /**
   * Builds one zero-initialized row of INPUT_SIZE numbers. The first half
   * encodes `base` and the second half encodes `moves.field`.
   */
  extractDataRow(player, base, moves) {
    const inputData = new Array(INPUT_SIZE).fill(0);
    this.extractFieldDataRow(inputData, 0, player, base);
    this.extractFieldDataRow(inputData, INPUT_SIZE / 2, player, moves.field);
    return inputData;
  }

  /** Index of the first row with the largest value in column 0 (0 if empty). */
  chooseBestIndex(outputData) {
    let best_index = 0;
    let best_estimate = Number.NEGATIVE_INFINITY;

    for (let index = 0; index < outputData.length; index++) {
      const estimate = outputData[index][0];
      if (estimate > best_estimate) {
        best_index = index;
        best_estimate = estimate;
      }
    }

    return best_index;
  }

  /** One input row per candidate in `possible_moves`, in the same order. */
  extractData(player, base, possible_moves) {
    return possible_moves.map((moves) =>
      this.extractDataRow(player, base, moves)
    );
  }

  /**
   * Runs the model on a batch of rows and returns the output as nested
   * arrays. The tensors created here are not disposed; call this inside
   * `tf.tidy`, as chooseMovesByEsstimate does.
   */
  infer(inputData) {
    const output = this.model.predict(
      tf.tensor2d(inputData, [inputData.length, INPUT_SIZE])
    );
    return output.arraySync();
  }

  /**
   * Scores every candidate with the model and returns the best one, or
   * `undefined` when there are no candidates.
   */
  chooseMovesByEsstimate(game, possible_moves) {
    // Nothing to score; also avoids feeding a zero-row tensor to the model.
    if (possible_moves.length === 0) return undefined;

    const best_index = tf.tidy(() => {
      const inputData = this.extractData(
        game.current_player,
        game.field,
        possible_moves
      );
      return this.chooseBestIndex(this.infer(inputData));
    });

    return possible_moves[best_index];
  }

  /** Uses the bear-off heuristic when possible, otherwise the network. */
  chooseMoves(game, possible_moves) {
    if (game.field.canDoOutStep(game.current_player)) {
      return chooseThrowMoves(game, possible_moves);
    }
    return this.chooseMovesByEsstimate(game, possible_moves);
  }
}
/*****************************************************************************/
/**
 * Logistic function 1 / (1 + e^(-z)), squashing any number into the range 0..1.
 */
function sigmoid(z) {
  const decay = Math.exp(-z);
  return 1 / (1 + decay);
}

export class MediumAIPlayer {
  /** Heuristic player whose parameters come from `MediumAIPlayer_data`. */
  constructor() {
    this.data = MediumAIPlayer_data;
  }

  /**
   * Scores `field` for `player`; higher is better. The score is the sum of:
   *  - HIDE_LEN hidden units, each
   *    `B[j] * sigmoid(sum_i W[i][j] * x_i + W0[j]) * exp(-A[j] * d_j / 2)`,
   *    where x_i = getDataForPlayer(player)[i] / 15 and
   *    d_j = sum_i ((x_i - S[i][j]) * S_M[i][j]) ** 2;
   *  - four sigmoid terms over getFieldDataForPlayer features, each with its
   *    own output weight W_AI[8], W_AI[9], W_AI[10], W_AI[11].
   */
  getEsstimate(player, field) {
    const { HIDE_LEN, INPUT_LEN, W, W0, S, S_M, A, B, W_AI } = this.data;
    const field_data = field.getDataForPlayer(player);
    const f = field.getFieldDataForPlayer(player);

    let estimate = 0;
    for (let j = 0; j < HIDE_LEN; j++) {
      let part_sum = 0;
      let dist_sum = 0;

      for (let i = 0; i < INPUT_LEN; i++) {
        const data_i = field_data[i] / 15;
        part_sum += W[i][j] * data_i;
        dist_sum += ((data_i - S[i][j]) * S_M[i][j]) ** 2;
      }

      const high_part = sigmoid(part_sum + W0[j]);
      const low_part = Math.exp((-A[j] * dist_sum) / 2);
      estimate += B[j] * high_part * low_part;
    }

    const blockAndRange = sigmoid(
      W_AI[0] * f.has_block +
        (W_AI[2] * f.min_pos) / 24 +
        (W_AI[3] * f.max_pos) / 24
    );
    const maxDist = sigmoid((W_AI[1] * f.max_dist) / 5);
    const block1 = sigmoid(
      (W_AI[4] * f.block_start_1) / 24 + (W_AI[5] * f.block_end_1) / 24
    );
    const block2 = sigmoid(
      (W_AI[6] * f.block_start_2) / 24 + (W_AI[7] * f.block_end_2) / 24
    );

    // Each feature term has its own output weight; the max_dist term
    // previously reused W_AI[8], so W_AI[9] was never read.
    return (
      estimate +
      W_AI[8] * blockAndRange +
      W_AI[9] * maxDist +
      W_AI[10] * block1 +
      W_AI[11] * block2
    );
  }

  /** Returns the candidate whose resulting field scores highest. */
  chooseMovesByEsstimate(game, possible_moves) {
    const outputData = possible_moves.map((moves) => [
      this.getEsstimate(game.current_player, moves.field),
    ]);
    return possible_moves[this.chooseBestIndex(outputData)];
  }

  /** Index of the first row with the largest value in column 0 (0 if empty). */
  chooseBestIndex(outputData) {
    let best_index = 0;
    let best_estimate = Number.NEGATIVE_INFINITY;

    for (let index = 0; index < outputData.length; index++) {
      const estimate = outputData[index][0];
      if (estimate > best_estimate) {
        best_index = index;
        best_estimate = estimate;
      }
    }

    return best_index;
  }

  /** Uses the bear-off heuristic when possible, otherwise the estimator. */
  chooseMoves(game, possible_moves) {
    if (game.field.canDoOutStep(game.current_player)) {
      return chooseThrowMoves(game, possible_moves);
    }
    return this.chooseMovesByEsstimate(game, possible_moves);
  }
}
/*****************************************************************************/
export class EasyAIPlayer {
  /**
   * Baseline player. A position's score is the sum of `w[i]` over every
   * index i < FIELD_SIZE where getDataForPlayer reports a positive value.
   */
  constructor() {
    this.w = [
      0.9705574703216554, 0.9805882167816165, 0.99127682685852,
      0.9957925701141359, 0.9829144191741936, 0.9857761096954342,
      1.0005077171325676, 0.997759599685669, 1.0124825000762943,
      0.9906270027160639, 0.9760481166839606, 0.9849176597595215,
      1.0013169860839843, 1.0005735206604003, 1.0000116157531747,
      0.9962803077697754, 1.0143678283691404, 0.999853248596192,
      1.0314336013793937, 1.0333253955841057, 1.0286037063598632,
      1.0115188121795646, 1.0075487251281736, 0.9923785495758056,
    ];
  }

  /** Position of the first row holding the maximum column-0 value; 0 if none. */
  chooseBestIndex(outputData) {
    let winner = 0;
    let top = Number.NEGATIVE_INFINITY;

    for (const [idx, row] of outputData.entries()) {
      const score = row[0];
      if (score > top) {
        top = score;
        winner = idx;
      }
    }

    return winner;
  }

  /** Scores each candidate's resulting field and returns the best candidate. */
  chooseMovesByEsstimate(game, possible_moves) {
    const player = game.current_player;

    const scores = possible_moves.map((candidate) => {
      const points = candidate.field.getDataForPlayer(player);
      let total = 0;
      for (let p = 0; p < FIELD_SIZE; p++) {
        if (points[p] > 0) total += this.w[p];
      }
      return [total];
    });

    return possible_moves[this.chooseBestIndex(scores)];
  }

  /** Bear-off heuristic when available, weighted occupancy otherwise. */
  chooseMoves(game, possible_moves) {
    return game.field.canDoOutStep(game.current_player)
      ? chooseThrowMoves(game, possible_moves)
      : this.chooseMovesByEsstimate(game, possible_moves);
  }
}

// Difficulty tiers exported by this module. The tier names are deliberately
// shifted relative to the class names: HARD uses the "uhard" network, MEDIUM
// uses the "hard" network, and EASY uses the MediumAIPlayer heuristic.
// EasyAIPlayer is exported as a class but is not assigned to any tier.
export const hard_ai_player = new HardAIPlayer(model_uhard);
export const medium_ai_player = new HardAIPlayer(model_hard);
export const easy_ai_player = new MediumAIPlayer();

/** Player kinds by difficulty; HUMAN is the plain marker value 0. */
export const NardsPlayers = {
  HUMAN: 0,
  HARD: hard_ai_player,
  MEDIUM: medium_ai_player,
  EASY: easy_ai_player,
};

export default NardsPlayers;
