-- Attempt to require `name`; yields the loaded module, or nil when the require raises.
local function prequire(name)
	local ok, loaded = pcall(require, name)
	if not ok then
		return nil
	end
	return loaded
end
-- Locate bench_support: through `script.Parent` when a `script` global exists, otherwise
-- by plain module name, falling back to a path relative to this file.
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")

-- Rank and file labels; a character's 1-based position minus one is the 0-based coordinate.
local RANKS = "12345678"
local FILES = "abcdefgh"
-- Piece letters in piece-id order (1..12): odd ids are white (uppercase), even ids are black.
local PieceSymbols = "PpRrNnBbQqKk"
-- Display glyphs indexed by the same piece ids as PieceSymbols.
local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"}
-- FEN of the standard initial position.
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

--
-- Lua 5.2 Compat
--

-- Fallback for runtimes without table.create: build an array of n slots, each holding v.
if not table.create then
	table.create = function(n, v)
		local filled = {}
		for slot = 1, n do
			filled[slot] = v
		end
		return filled
	end
end

-- Fallback for runtimes without table.move.
-- Copies a[from..to] into target[start..]; `target` defaults to `a`.
-- Returns the destination table, matching the standard library contract.
if not table.move then
	function table.move(a, from, to, start, target)
		target = target or a
		if to >= from then
			local dx = start - from
			if target ~= a or start <= from or start > to then
				-- No harmful overlap: a plain ascending copy is safe.
				for i = from, to do
					target[i + dx] = a[i]
				end
			else
				-- Destination overlaps the tail of the source: copy backwards so
				-- elements are read before they are overwritten.
				for i = to, from, -1 do
					target[i + dx] = a[i]
				end
			end
		end
		return target
	end
end


--
-- Utils
-- 

-- Convert an algebraic square name such as "e4" into a 0-based index (a1 = 0, h8 = 63).
local function square(s)
	local fileIdx = FILES:find(s:sub(1, 1))
	local rankIdx = RANKS:find(s:sub(2, 2))
	return rankIdx * 8 + fileIdx - 9
end

-- Convert a 0-based square index into its algebraic name (file letter then rank digit).
local function squareName(n)
	local fileIdx = n % 8
	local rankIdx = (n - fileIdx) / 8
	local fileChar = FILES:sub(fileIdx + 1, fileIdx + 1)
	local rankChar = RANKS:sub(rankIdx + 1, rankIdx + 1)
	return fileChar .. rankChar
end

-- Human-readable name for an encoded move: "O-O"/"O-O-O" for castling, otherwise
-- "<piece> <from>-<to>" ("x" instead of "-" for captures) with "=<piece>" on promotion.
local function moveName(v)
	local from = bit32.extract(v, 6, 6)
	local to = bit32.extract(v, 0, 6)

	-- Castling is decided purely by the castle flag and the direction of travel.
	if bit32.extract(v, 14) == 1 then
		if to > from then
			return "O-O"
		end
		return "O-O-O"
	end

	local piece = bit32.extract(v, 20, 4)
	local captured = bit32.extract(v, 25, 4)
	local separator = '-'
	if captured ~= 0 then
		separator = 'x'
	end

	local parts = {
		PieceSymbols:sub(piece, piece),
		' ',
		squareName(from),
		separator,
		squareName(to),
	}

	local promote = bit32.extract(v, 15, 4)
	if promote ~= 0 then
		parts[#parts + 1] = "="
		parts[#parts + 1] = PieceSymbols:sub(promote, promote)
	end
	return table.concat(parts)
end

-- UCI-style text for an encoded move: from-square, to-square, then a lowercase
-- promotion letter when the promotion field is non-zero.
local function ucimove(m)
	local fromName = squareName(bit32.extract(m, 6, 6))
	local toName = squareName(bit32.extract(m, 0, 6))
	local promote = bit32.extract(m, 15, 4)
	if promote <= 0 then
		return fromName .. toName
	end
	return fromName .. toName .. PieceSymbols:sub(promote, promote):lower()
end

-- Holds references to the debugging helpers above; nothing else in the visible file reads
-- this table (presumably it only keeps them from being reported as unused — confirm).
local _utils = {squareName, moveName}

--
-- Bitboards
-- 

-- Method table (and metatable) for 64-bit boards stored as two 32-bit halves:
-- `l` holds squares 0-31 and `h` holds squares 32-63.
local Bitboard = {}


-- Render the bitboard as an 8x8 grid with rank 8 on top ("x" for set, "-" for clear),
-- followed by a file legend and a summary line with the popcount and both raw halves.
function Bitboard:toString()
	local parts = {}

	for rank = 7, 0, -1 do
		-- Ranks 5-8 live in the high word, ranks 1-4 in the low word.
		local word = rank >= 4 and self.h or self.l
		local mask = bit32.lshift(1, (rank % 4) * 8)

		parts[#parts + 1] = RANKS:sub(rank + 1, rank + 1)
		parts[#parts + 1] = " "
		for _ = 0, 7 do
			if bit32.band(word, mask) == 0 then
				parts[#parts + 1] = "- "
			else
				parts[#parts + 1] = "x "
			end
			mask = bit32.lshift(mask, 1)
		end
		parts[#parts + 1] = "\n"
	end

	-- gsub also returns a match count; keep only the rewritten string.
	local legend = FILES:gsub('.', '%1 ')
	parts[#parts + 1] = '  ' .. legend .. '\n'
	parts[#parts + 1] = '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h
	return table.concat(parts)
end


-- Construct a bitboard from its low (squares 0-31) and high (squares 32-63) words.
function Bitboard.from(l, h)
	local bb = { l = l, h = h }
	setmetatable(bb, Bitboard)
	return bb
end

-- Shared empty and all-ones boards.
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)

-- Rank masks: one byte per rank, ranks 1-4 in the low word and 5-8 in the high word.
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
-- File masks: the same bit of every rank byte (bit 0 = file a ... bit 7 = file h).
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)

-- All file masks in order a..h; not referenced elsewhere in the visible file.
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}

-- These masks are filled out below for all files
-- Entry i is used by Bitboard:move to clear the i edge files before a horizontal shift of i.
local RightMasks = {FileH}
local LeftMasks = {FileA}



-- Count the set bits of a 32-bit value by summing in parallel:
-- bit pairs, then nibbles, then bytes, with a multiply gathering the byte sums.
local function popcnt32(i)
	local pairSums = i - bit32.band(bit32.rshift(i, 1), 0x55555555)
	local nibbleSums = bit32.band(pairSums, 0x33333333)
		+ bit32.band(bit32.rshift(pairSums, 2), 0x33333333)
	local byteSums = bit32.band(nibbleSums + bit32.rshift(nibbleSums, 4), 0x0F0F0F0F)
	-- The multiply accumulates all four byte sums into the top byte.
	return bit32.rshift(byteSums * 0x01010101, 24)
end

-- Shift every square one rank toward rank 8 (eight bit positions upward).
function Bitboard:up()
	local shifted = self:lshift(8)
	return shifted
end

-- Shift every square one rank toward rank 1 (eight bit positions downward).
function Bitboard:down()
	local shifted = self:rshift(8)
	return shifted
end

-- Shift one file toward the h-file; squares already on the h-file are dropped
-- instead of wrapping onto the next rank.
function Bitboard:right()
	local inside = self:bandnot(FileH)
	return inside:lshift(1)
end

-- Shift one file toward the a-file; squares already on the a-file are dropped
-- instead of wrapping onto the previous rank.
function Bitboard:left()
	local inside = self:bandnot(FileA)
	return inside:rshift(1)
end

-- Translate the board by (x, y). Negative x moves toward the h-file and positive x
-- toward the a-file, clearing the files that would cross the edge first.
-- Positive y moves toward rank 8, negative y toward rank 1.
function Bitboard:move(x, y)
	local shifted = self

	if x < 0 then
		local steps = -x
		shifted = shifted:bandnot(RightMasks[steps]):lshift(steps)
	elseif x > 0 then
		shifted = shifted:bandnot(LeftMasks[x]):rshift(x)
	end

	if y < 0 then
		shifted = shifted:rshift(-8 * y)
	elseif y > 0 then
		shifted = shifted:lshift(8 * y)
	end
	return shifted
end


-- Number of set squares across both halves.
function Bitboard:popcnt()
	local low = popcnt32(self.l)
	local high = popcnt32(self.h)
	return low + high
end

-- Intersection of two bitboards, returned as a new bitboard.
function Bitboard:band(other)
	local low = bit32.band(self.l, other.l)
	local high = bit32.band(self.h, other.h)
	return Bitboard.from(low, high)
end

-- Squares set in self but not in other, returned as a new bitboard.
function Bitboard:bandnot(other)
	local keepLow = bit32.bnot(other.l)
	local keepHigh = bit32.bnot(other.h)
	return Bitboard.from(bit32.band(self.l, keepLow), bit32.band(self.h, keepHigh))
end

-- True when self and other share no set square (no intermediate bitboard is allocated).
function Bitboard:bandempty(other)
	if bit32.band(self.l, other.l) ~= 0 then
		return false
	end
	return bit32.band(self.h, other.h) == 0
end

-- Union of two bitboards, returned as a new bitboard.
function Bitboard:bor(other)
	local low = bit32.bor(self.l, other.l)
	local high = bit32.bor(self.h, other.h)
	return Bitboard.from(low, high)
end

-- Symmetric difference of two bitboards, returned as a new bitboard.
function Bitboard:bxor(other)
	local low = bit32.bxor(self.l, other.l)
	local high = bit32.bxor(self.h, other.h)
	return Bitboard.from(low, high)
end

-- Complement: every set square cleared and every clear square set.
function Bitboard:inverse()
	local ALL_ONES = 0xFFFFFFFF
	local low = bit32.bxor(self.l, ALL_ONES)
	local high = bit32.bxor(self.h, ALL_ONES)
	return Bitboard.from(low, high)
end

-- True when no square is set.
function Bitboard:empty()
	if self.h ~= 0 then
		return false
	end
	return self.l == 0
end

-- Bitboard:ctz() returns the index of the lowest set square (64 when empty).
-- Bitboard:ctzafter(start) returns the lowest set square strictly above `start`
-- (64 when there is none). bit32.countrz is used when the runtime provides it;
-- otherwise a bit-by-bit scan is installed.
if bit32.countrz then
	function Bitboard:ctz()
		local low = bit32.countrz(self.l)
		if low ~= 32 then
			return low
		end
		return bit32.countrz(self.h) + 32
	end
	function Bitboard:ctzafter(start)
		-- Keep only the squares above `start`, then take the lowest of those.
		local above = Bitboard.full:lshift(start + 1)
		return self:band(above):ctz()
	end
else
	-- Lowest set bit of a 32-bit word, or 32 when the word is zero.
	local function lowestBit(v)
		if v == 0 then return 32 end
		local pos = 0
		while bit32.extract(v, pos) ~= 1 do
			pos = pos + 1
		end
		return pos
	end
	function Bitboard:ctz()
		local low = lowestBit(self.l)
		if low ~= 32 then
			return low
		end
		return lowestBit(self.h) + 32
	end
	function Bitboard:ctzafter(start)
		for i = start + 1, 63 do
			if self:index(i) == 1 then return i end
		end
		return 64
	end
end


-- 64-bit logical left shift by `amt` (amt >= 0); returns self unchanged for amt == 0.
function Bitboard:lshift(amt)
	assert(amt >= 0)
	if amt == 0 then return self end

	if amt > 31 then
		-- Everything in the low word moves entirely into the high word.
		local high = bit32.lshift(self.l, amt - 32)
		return Bitboard.from(0, high)
	end

	-- The top `amt` bits of the low word carry into the bottom of the high word.
	local carried = bit32.extract(self.l, 32 - amt, amt)
	local high = bit32.bor(bit32.lshift(self.h, amt), carried)
	local low = bit32.lshift(self.l, amt)
	return Bitboard.from(low, high)
end

-- 64-bit logical right shift by `amt` (amt >= 0); returns self unchanged for amt == 0.
-- Shifts of 32 or more are handled explicitly, mirroring lshift: previously
-- bit32.extract was asked for more than 32 bits and raised for amt > 32.
function Bitboard:rshift(amt)
	assert(amt >= 0)
	if amt == 0 then return self end

	if amt > 31 then
		-- Everything in the high word moves entirely into the low word.
		return Bitboard.from(bit32.rshift(self.h, amt - 32), 0)
	end

	local h = bit32.rshift(self.h, amt)
	-- The bottom `amt` bits of the high word carry into the top of the low word.
	local l = bit32.bor(
		bit32.rshift(self.l, amt),
		bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
	)
	return Bitboard.from(l, h)
end

-- Value (0 or 1) of square `i`, where squares 32-63 live in the high word.
function Bitboard:index(i)
	local word, pos = self.l, i
	if i > 31 then
		word, pos = self.h, i - 32
	end
	return bit32.extract(word, pos)
end

-- Return a new bitboard equal to self with square `i` replaced by bit value `v`.
function Bitboard:set(i, v)
	if i > 31 then
		local high = bit32.replace(self.h, v, i - 32)
		return Bitboard.from(self.l, high)
	end
	local low = bit32.replace(self.l, v, i)
	return Bitboard.from(low, self.h)
end

-- Keep only square `i` of self (empty result when that square is clear).
function Bitboard:isolate(i)
	local mask = Bitboard.some(i)
	return self:band(mask)
end

-- Bitboard with exactly square `idx` set.
function Bitboard.some(idx)
	local blank = Bitboard.zero
	return blank:set(idx, 1)
end

-- Instances resolve methods through the Bitboard table; tostring() renders via toString.
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString

-- Extend the edge masks: entry i covers the i files nearest the h-file (RightMasks)
-- or nearest the a-file (LeftMasks).
for i=2,8 do
	RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
	LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
end
--
-- Board
--

-- Method table (and metatable) for positions. A board keeps one bitboard per piece id
-- in slots 1..12 plus named fields for occupancy, en passant, castling and counters.
local Board = {}


-- Create an empty position: no pieces, no castling rights, no en-passant square,
-- white to move, and all counters at zero.
function Board.new()
	local none = Bitboard.zero
	local boards = table.create(12, none)

	-- Occupancy caches (normally refreshed by updateCache).
	boards.white = none
	boards.black = none
	boards.ocupied = none
	boards.unocupied = Bitboard.full

	-- Special-move state.
	boards.ep = none
	boards.castle = none

	-- Side to move (1 = white) and counters.
	boards.toMove = 1
	boards.hm = 0
	boards.moves = 0
	boards.material = 0

	return setmetatable(boards, Board)
end

-- Build a Board from a FEN string.
--   fen: "<placement> <w|b> <castling> <en passant> <halfmove> <fullmove>"
-- Returns the populated board with its caches refreshed.
-- Raises an error when the fields after the piece placement cannot be parsed.
function Board.fromFen(fen )
	local b = Board.new()
	local i = 0
	local rank = 7
	local file = 0

	-- Piece placement: ranks are listed from 8 down to 1, files from a to h.
	while true do
		i = i + 1
		local p = fen:sub(i,i)
		if p == '/' then
			rank = rank - 1
			file = 0
		elseif tonumber(p) ~= nil then
			-- A digit skips that many empty squares.
			file = file + tonumber(p)
		else
			-- Plain-text lookup: a pattern search would treat "" as a match at
			-- position 1 (looping past the end of the string) and would
			-- misread magic characters such as "." or "%".
			local pidx = nil
			if p ~= '' then
				pidx = PieceSymbols:find(p, 1, true)
			end
			if pidx == nil then break end
			b[pidx] = b[pidx]:set(rank*8+file, 1)
			file = file + 1
		end
	end

	-- `i` now points at the first character after the placement field.
	local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
	if move == nil then
		-- Fail with a clear message instead of an obscure nil error further down.
		error("malformed FEN, cannot parse fields after piece placement: " .. fen:sub(i), 2)
	end
	b.toMove = move == 'w' and 1 or 2

	if ep ~= "-" then
		b.ep = Bitboard.some(square(ep))
	end

	if castle ~= "-" then
		-- Castling rights are stored as the rook's corner square for each right.
		local oo = Bitboard.zero
		if castle:find("K") then
			oo = oo:set(7, 1)
		end
		if castle:find("Q") then
			oo = oo:set(0, 1)
		end
		if castle:find("k") then
			oo = oo:set(63, 1)
		end
		if castle:find("q") then
			oo = oo:set(56, 1)
		end

		b.castle = oo
	end

	-- Store the counters as numbers (as Board.new does) rather than as captured strings.
	b.hm = tonumber(hm) or 0
	b.moves = tonumber(m) or 0

	b:updateCache()
	return b

end

-- Piece id (1..12) standing on square `idx`, or 0 when no piece board has it set.
-- The cached white occupancy decides whether odd (white) or even (black) ids are scanned.
function Board:index(idx)
	local first = 2
	if self.white:index(idx) == 1 then
		first = 1
	end

	for p = first, 12, 2 do
		if self[p]:index(idx) == 1 then
			return p
		end
	end
	return 0
end

-- Recompute the derived fields (white, black, ocupied, unocupied, material)
-- from the twelve piece bitboards.
-- The colour sets are rebuilt from scratch, so calling this again after pieces
-- were removed does not leave stale squares behind.
function Board:updateCache()
	local white = Bitboard.zero
	local black = Bitboard.zero
	-- Odd slots hold white pieces, the following even slot the matching black piece.
	for i=1,11,2 do
		white = white:bor(self[i])
		black = black:bor(self[i+1])
	end
	self.white = white
	self.black = black

	self.ocupied = self.black:bor(self.white)
	self.unocupied = self.ocupied:inverse()
	-- Material balance in centipawns, positive when white is ahead; kings are not counted.
	self.material =
		100*self[1]:popcnt() - 100*self[2]:popcnt() +
		500*self[3]:popcnt() - 500*self[4]:popcnt() +
		300*self[5]:popcnt() - 300*self[6]:popcnt() +
		300*self[7]:popcnt() - 300*self[8]:popcnt() +
		900*self[9]:popcnt() - 900*self[10]:popcnt()

end

-- Serialise the position as a FEN string: placement, side to move, castling
-- rights, en-passant square, then the hm and moves counters.
function Board:fen()
	local out = {}
	local gap = 0

	-- Emit the pending run of empty squares, if any.
	local function flushGap()
		if gap > 0 then
			out[#out + 1] = '' .. gap
			gap = 0
		end
	end

	-- Placement is written from rank 8 down to rank 1, files a to h.
	for rank = 7, 0, -1 do
		if rank < 7 then
			flushGap()
			out[#out + 1] = '/'
		end
		for file = 0, 7 do
			local p = self:index(rank * 8 + file)
			if p == 0 then
				gap = gap + 1
			else
				flushGap()
				out[#out + 1] = PieceSymbols:sub(p, p)
			end
		end
	end
	flushGap()

	if self.toMove == 1 then
		out[#out + 1] = ' w '
	else
		out[#out + 1] = ' b '
	end

	local rights = self.castle
	if rights:empty() then
		out[#out + 1] = '-'
	else
		-- Each right is stored on the corresponding rook's corner square.
		if rights:index(7) == 1 then out[#out + 1] = 'K' end
		if rights:index(0) == 1 then out[#out + 1] = 'Q' end
		if rights:index(63) == 1 then out[#out + 1] = 'k' end
		if rights:index(56) == 1 then out[#out + 1] = 'q' end
	end

	out[#out + 1] = ' '
	if self.ep:empty() then
		out[#out + 1] = '-'
	else
		out[#out + 1] = squareName(self.ep:ctz())
	end

	out[#out + 1] = ' ' .. self.hm
	out[#out + 1] = ' ' .. self.moves

	return table.concat(out)
end

-- Pseudo-legal destination squares for the piece on `idx` (same result as generate).
function Board:pmoves(idx)
	local destinations = self:generate(idx)
	return destinations
end

-- Pseudo-legal destinations of the piece on `idx` that land on an occupied square.
function Board:pcaptures(idx)
	local destinations = self:generate(idx)
	return destinations:band(self.ocupied)
end

-- Direction steps as {x, y} pairs passed to Bitboard:move.
-- Sliders repeat their step until blocked; knights take each step once.
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}

-- Pseudo-legal destinations for the piece on square `idx`, as a bitboard.
-- Own pieces block, enemy pieces can be captured; whether the mover's king
-- is left in check is not considered here. Returns an empty board for an empty square.
function Board:generate(idx)
	local piece = self:index(idx)
	if piece == 0 then return Bitboard.zero end

	local origin = Bitboard.some(idx)
	-- Piece kind: 0 pawn, 1 rook, 2 knight, 3 bishop, 4 queen, 5 king.
	local kind = bit32.rshift(piece - 1, 1)
	-- Odd piece ids are white, so their capturable set is the black pieces.
	local enemies = piece % 2 == 1 and self.black or self.white
	local reach = Bitboard.zero

	if kind == 0 then
		-- Pawn: white (1) advances with d = 1, black (2) with d = -1.
		local d = -(piece*2 - 3)
		-- Rank a pawn stands on after a single step from its home rank.
		local doubleFrom = piece == 1 and Rank3 or Rank6
		local ahead = origin:move(0, d)

		-- Single push, then a second push for pawns that just reached doubleFrom.
		reach = reach:bor(ahead:band(self.unocupied))
		reach = reach:bor(reach:band(doubleFrom):move(0, d):band(self.unocupied))

		local diagonals = ahead:right():bor(ahead:left())
		if not diagonals:bandempty(self.ep) then
			reach = reach:bor(self.ep)
		end

		return reach:bor(diagonals:band(enemies))
	end

	if kind == 5 then
		-- King: the eight neighbours (the (0,0) step is its own square, which is rejected).
		for x = -1, 1 do
			for y = -1, 1 do
				local w = origin:move(x, y)
				if self.ocupied:bandempty(w) or not enemies:bandempty(w) then
					reach = reach:bor(w)
				end
			end
		end
		return reach
	end

	if kind == 2 then
		-- Knight
		for _, jump in ipairs(KNIGHT_MOVES) do
			local w = origin:move(jump[1], jump[2])
			if self.ocupied:bandempty(w) or not enemies:bandempty(w) then
				reach = reach:bor(w)
			end
		end
		return reach
	end

	-- Sliders (Rook, Bishop, Queen)
	local slides = QUEEN_SLIDES
	if kind == 1 then
		slides = ROOK_SLIDES
	elseif kind == 3 then
		slides = BISHOP_SLIDES
	end

	for _, op in ipairs(slides) do
		local w = origin
		for _ = 1, 7 do
			w = w:move(op[1], op[2])
			-- Ran off the board.
			if w:empty() then break end

			local vacant = self.ocupied:bandempty(w)
			if vacant or not enemies:bandempty(w) then
				reach = reach:bor(w)
			end
			-- Any occupied square ends the ray, whether or not it was capturable.
			if not vacant then break end
		end
	end

	return reach
end

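-- Move encoding (bit fields of a move id):
-- 0-5 - To Square
-- 6-11 - From Square
-- 12 - Is Check
-- 13 - Is En Passant
-- 14 - Is Castle
-- 15-19 - Promotion Piece (the code reads only the low 4 bits of this field)
-- 20-24 - Moved Piece (the code reads only the low 4 bits of this field)
-- 25-29 - Captured Piece (the code reads only the low 4 bits of this field)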


-- Render the position as text with rank 8 on top, using Unicode piece glyphs
-- and "-" for empty squares. When `mark` (a bitboard) is given, marked squares
-- are wrapped in parentheses. A file legend and a line with the side to move
-- and the material balance follow.
function Board:toString(mark)
	local out = {}
	for rank = 8, 1, -1 do
		out[#out + 1] = RANKS:sub(rank, rank) .. " "

		for file = 1, 8 do
			local n = 8*rank + file - 9
			local piece = self:index(n)
			if piece ~= 0 then
				out[#out + 1] = UnicodePieces[piece]
			else
				out[#out + 1] = '-'
			end

			-- The cell after a square closes its own mark, or opens the mark
			-- of the next square on the same rank.
			local sep = ' '
			if mark ~= nil then
				if mark:index(n) ~= 0 then
					sep = ')'
				elseif n < 63 and file < 8 and mark:index(n+1) ~= 0 then
					sep = '('
				end
			end
			out[#out + 1] = sep
		end

		out[#out + 1] = "\n"
	end

	-- gsub also returns a match count; keep only the rewritten string.
	local legend = FILES:gsub('.', '%1 ')
	out[#out + 1] = '  ' .. legend .. '\n'

	local side = "Black"
	if self.toMove == 1 then
		side = "White"
	end
	out[#out + 1] = side .. ' e:' .. (self.material/100) .. "\n"
	return table.concat(out)
end

-- List every legal move for the side to move, as encoded move ids.
-- Castling candidates are produced first, then each piece's pseudo-legal
-- destinations; a candidate is kept only if the resulting position does not
-- leave the mover's own king attacked.
-- Returns an empty list when the side to move has no pieces (previously the
-- scan indexed square 64 and raised).
function Board:moveList()  
	local tm = self.toMove == 1 and self.white or self.black
	local castle_rank = self.toMove == 1 and Rank1 or Rank8
	local out = {}
	-- Keep `id` only when playing it does not leave the mover in check.
	local function emit(id)
		if not self:applyMove(id):illegalyChecked() then
			table.insert(out, id)
		end	
	end

	-- Castling rights are corner squares; a right counts only while one of the
	-- mover's own pieces still stands on that corner of its back rank.
	local cr = tm:band(self.castle):band(castle_rank)
	if not cr:empty() then
		local p = self.toMove == 1 and 11 or 12
		local tcolor = self.toMove == 1 and self.black or self.white
		local kidx = self[p]:ctz()


		-- Template: moved piece = king, from = king square, castle flag set.
		local castle = bit32.replace(0, p, 20, 4)
		castle = bit32.replace(castle, kidx, 6, 6)
		castle = bit32.replace(castle, 1, 14)
		

		-- Queenside: files b, c, d must be empty; c, d and the king square unattacked.
		local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
		local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
		if
			not cr:bandempty(FileA) and
			mustbeemptyl:bandempty(self.ocupied) and
			not self:isSquareThreatened(cantbethreatened, tcolor)
		then
			emit(bit32.replace(castle, kidx - 2, 0, 6))
		end
			

		-- Kingside: files f, g must be empty; f, g and the king square unattacked.
		local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
		if
			not cr:bandempty(FileH) and
			mustbeemptyr:bandempty(self.ocupied) and
			not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
		then
			emit(bit32.replace(castle, kidx + 2, 0, 6))
		end
	end

	-- Walk the mover's pieces from the lowest square upward; ctz yields 64
	-- when there is nothing (left) to visit.
	local sq = tm:ctz()
	while sq < 64 do
		local p = self:index(sq)
		local moves = self:pmoves(sq)

		while not moves:empty() do
			local m = moves:ctz()
			moves = moves:set(m, 0)
			-- Encode: to in bits 0-5, from in bits 6-11, moved piece at bit 20.
			local id = bit32.replace(m, sq, 6, 6)
			id = bit32.replace(id, p, 20, 4)
			local mbb = Bitboard.some(m)
			if not self.ocupied:bandempty(mbb) then
				id = bit32.replace(id, self:index(m), 25, 4)
			end

			-- Check if pawn needs to be promoted
			if p == 1 and m >= 8*7 then
				-- White promotes to rook, knight, bishop or queen (ids 3, 5, 7, 9).
				for i=3,9,2 do
					emit(bit32.replace(id, i, 15, 4))
				end
			elseif p == 2 and m < 8 then
				-- Black promotes to the matching black pieces (ids 4, 6, 8, 10).
				for i=4,10,2 do
					emit(bit32.replace(id, i, 15, 4))
				end
			else
				emit(id)
			end	
		end
		sq = tm:ctzafter(sq)
	end
	return out
end

-- True when the king of the side that is NOT to move can be reached by a piece of
-- the side to move, i.e. the move that produced this position left its own king in check.
function Board:illegalyChecked()
	local kingBoard, attackers
	if self.toMove == 1 then
		kingBoard = self[PieceSymbols:find("k")]
		attackers = self.white
	else
		kingBoard = self[PieceSymbols:find("K")]
		attackers = self.black
	end
	return self:isSquareThreatened(kingBoard, attackers)
end

-- True when any piece standing on a square of `color` has a pseudo-legal move
-- onto a square of `target`.
--   target: bitboard of squares to test
--   color:  bitboard of the attacking pieces' squares
-- An empty `color` yields false (previously the scan indexed square 64 and raised).
-- NOTE(review): threats are taken from pmoves, whose pawn diagonals only count
-- when the destination holds a capturable piece (or is the ep square), so a pawn
-- attack on an EMPTY target square is not reported — confirm this is acceptable
-- for the castling-path checks.
function Board:isSquareThreatened(target , color )  
	local sq = color:ctz()
	while sq < 64 do
		if not self:pmoves(sq):bandempty(target) then
			return true
		end
		sq = color:ctzafter(sq)
	end
	return false
end

-- Count the leaf positions reachable in exactly `depth` plies of legal moves.
-- Depth 0 counts the current position itself; branches that run out of moves
-- before the target depth contribute nothing.
function Board:perft(depth)
	if depth == 0 then return 1 end

	local moves = self:moveList()
	if depth == 1 then
		return #moves
	end

	local total = 0
	for i = 1, #moves do
		total = total + self:applyMove(moves[i]):perft(depth - 1)
	end
	return total
end


-- Return a NEW board with the encoded `move` played; self is left untouched.
-- Only the from/to squares, the promotion field and the castle flag are read
-- from the move id; the moved and captured pieces are looked up on the board.
--
-- Fixes over the previous version:
--   * the fullmove counter is carried over on White's moves (it used to reset to 0);
--   * the halfmove clock resets on ANY capture, not only when a white pawn is taken;
--   * castling removes just the rook on the castling corner, not every own rook
--     on the a/h file;
--   * capturing a rook clears the castling right attached to its square.
function Board:applyMove(move )  
	local out = Board.new()
	table.move(self, 1, 12, 1, out)
	local from = bit32.extract(move, 6, 6)
	local to = bit32.extract(move, 0, 6)
	local promote = bit32.extract(move, 15, 4)
	local piece = self:index(from)
	local captured = self:index(to)
	local tom = Bitboard.some(to)
	local isCastle = bit32.extract(move, 14)

	-- The fullmove number advances after Black (even piece ids) has moved.
	if piece % 2 == 0 then
		out.moves = self.moves + 1
	else
		out.moves = self.moves
	end

	-- Halfmove clock: reset by any capture or pawn move, otherwise incremented.
	if captured ~= 0 or piece < 3 then
		out.hm = 0
	else
		out.hm = self.hm + 1
	end
	out.castle = self.castle
	out.toMove = self.toMove == 1 and 2 or 1

	if isCastle == 1 then
		local rank = piece == 11 and Rank1 or Rank8
		-- 0 for white, 1 for black; slot 3 + colorOffset holds that side's rooks.
		local colorOffset = piece - 11
		local rooks = 3 + colorOffset

		-- Move only the rook standing on the castling corner of the back rank.
		local corner = (from < to and FileH or FileA):band(rank)
		out[rooks] = out[rooks]:bandnot(corner)
		out[rooks] = out[rooks]:bor((from < to and FileF or FileD):band(rank))

		out[piece] = (from < to and FileG or FileC):band(rank)
		-- Castling spends both of this side's rights.
		out.castle = out.castle:bandnot(rank)
		out:updateCache()
		return out
	end

	if piece < 3 then
		local dist = math.abs(to - from)
		-- Pawn moved two squares, set ep square
		if dist == 16 then
			out.ep = Bitboard.some((from + to) / 2)
		end

		-- Remove the pawn taken en passant (it stands behind the ep square)
		if not tom:bandempty(self.ep) then
			if piece == 1 then
				out[2] = out[2]:bandnot(self.ep:down())
			end
			if piece == 2 then
				out[1] = out[1]:bandnot(self.ep:up())
			end
		end
	end

	-- A rook leaving its square gives up the right stored on that square.
	if piece == 3 or piece == 4 then
		out.castle = out.castle:set(from, 0)
	end

	-- A rook captured on its square takes the matching right with it; clearing a
	-- non-corner square is harmless because rights only live on the corners.
	if captured == 3 or captured == 4 then
		out.castle = out.castle:set(to, 0)
	end

	-- Any king move gives up both of that side's rights.
	if piece > 10 then
		local rank = piece == 11 and Rank1 or Rank8
		out.castle = out.castle:bandnot(rank)
	end

	out[piece] = out[piece]:set(from, 0)
	if promote == 0 then
		out[piece] = out[piece]:set(to, 1)
	else
		out[promote] = out[promote]:set(to, 1)
	end
	if captured ~= 0 then
		out[captured] = out[captured]:set(to, 0)
	end

	out:updateCache()
	return out
end

-- Instances resolve methods through the Board table; tostring() renders via toString.
Board.__index = Board
Board.__tostring = Board.toString
--
-- Main
--

-- Number of test cases that failed so far.
local failures = 0

-- Run one perft case: parse `fen`, verify it serialises back to the same string,
-- then compare perft(ply) against `target`. On a count mismatch a per-move
-- breakdown is printed to help locate the divergence.
local function test(fen, ply, target)
	local board = Board.fromFen(fen)
	local roundTrip = board:fen()
	if roundTrip ~= fen then
		print("FEN MISMATCH", fen, roundTrip)
		failures = failures + 1
		return
	end

	local found = board:perft(ply)
	if found == target then
		print("OK", found, fen)
		return
	end

	print(fen, "Found", found, "target", target)
	failures = failures + 1
	for k,v in pairs(board:moveList()) do
		local count = ply > 1 and board:applyMove(v):perft(ply-1) or '1'
		print(ucimove(v) .. ': ' .. count)
	end
	--error("Test Failure")
end

-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper

-- Each case is {fen, ply, expected perft count}.
local testCases = {}
-- Append one case, stored as an array of the given arguments.
local function addTest(...) table.insert(testCases, {...}) end

addTest(StartingFen, 2, 400)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46)


-- Benchmark body: run every registered perft case in registration order.
local function chess()
	for i = 1, #testCases do
		local case = testCases[i]
		test(case[1], case[2], case[3])
	end
end

-- Hand the perft suite to the bench_support harness under the label "chess"
-- (runCode is defined in bench_support, which is not visible here).
bench.runCode(chess, "chess")