--[[
Repository-mirror page header captured together with this source. It is kept
inside a comment so that the file remains valid Lua/Luau:

1
0
Fork 0
mirror of https://github.com/luau-lang/luau.git synced 2025-03-05 03:31:41 +00:00
luau/bench/tests/chess.lua
Vyacheslav Egorov aafea36235
Fixed the backwards compatible benchmark support library require ()
Previous benchmark require fix wasn't actually working correctly for the
old style require (or running in Lua).
2023-12-04 12:48:31 -08:00

860 lines
19 KiB
Lua
]]

-- Protected require: returns the module when require(name) succeeds, or nil
-- (discarding the error) when it fails. Written as a plain if-statement rather
-- than a Luau if-expression so that the file also parses under standard Lua.
local function prequire(name)
	local success, result = pcall(require, name)
	if success then
		return result
	end
	return nil
end
-- Locate the benchmark harness: through the Roblox-style `script` instance tree
-- when it exists, otherwise by plain module name, otherwise by relative path.
local bench
if script then
	bench = require(script.Parent.bench_support)
end
bench = bench or prequire("bench_support") or require("../bench_support")
local RANKS = "12345678"
local FILES = "abcdefgh"
-- Piece codes are 1-based positions in this string: odd = White, even = Black,
-- in the order pawn, rook, knight, bishop, queen, king.
local PieceSymbols = "PpRrNnBbQqKk"
-- Display glyphs used by Board:toString, indexed exactly like PieceSymbols.
-- (Previously these were all empty strings, so pieces rendered as nothing.)
local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
--
-- Runtime compatibility
--
-- table.create is a Luau extension and table.move only exists from Lua 5.3,
-- so install fallbacks when the host runtime lacks them.
--

-- Fallback for table.create: an array holding n copies of v.
local function tableCreateCompat(n, v)
	local result = {}
	for i = 1, n do
		result[i] = v
	end
	return result
end

-- Fallback for table.move(a, from, to, start [, target]): copies a[from..to]
-- into target[start..] and returns target; target defaults to a. As with the
-- built-in, an overlapping move inside one table is copied back-to-front when
-- the destination starts inside the source range, so no element is clobbered
-- before it has been read.
local function tableMoveCompat(a, from, to, start, target)
	target = target or a
	if to >= from then
		if start > to or start <= from or target ~= a then
			for offset = 0, to - from do
				target[start + offset] = a[from + offset]
			end
		else
			for offset = to - from, 0, -1 do
				target[start + offset] = a[from + offset]
			end
		end
	end
	return target
end

if not table.create then
	table.create = tableCreateCompat
end
if not table.move then
	table.move = tableMoveCompat
end
--
-- Utils
--
-- Convert algebraic notation such as "e3" into a 0-based square index
-- (a1 = 0, h1 = 7, a8 = 56, h8 = 63).
local function square(s)
	local rankIdx = RANKS:find(s:sub(2, 2))
	local fileIdx = FILES:find(s:sub(1, 1))
	return (rankIdx - 1) * 8 + (fileIdx - 1)
end
-- Inverse of square: 0-based index back to algebraic notation ("a1".."h8").
local function squareName(n)
	local fileIdx = n % 8 + 1
	local rankIdx = (n - (fileIdx - 1)) / 8 + 1
	return FILES:sub(fileIdx, fileIdx) .. RANKS:sub(rankIdx, rankIdx)
end
-- Human-readable description of an encoded move, e.g. "P e2-e4", "n g8xf6",
-- "P a7-a8=Q", or "O-O" / "O-O-O" for castling.
local function moveName(v)
	local from = bit32.extract(v, 6, 6)
	local to = bit32.extract(v, 0, 6)
	if bit32.extract(v, 14) == 1 then
		-- Castling: the king's direction of travel tells the two sides apart.
		return to > from and "O-O" or "O-O-O"
	end
	local piece = bit32.extract(v, 20, 4)
	local separator = bit32.extract(v, 25, 4) ~= 0 and 'x' or '-'
	local text = PieceSymbols:sub(piece, piece) .. ' ' .. squareName(from) .. separator .. squareName(to)
	local promote = bit32.extract(v, 15, 4)
	if promote ~= 0 then
		text = text .. "=" .. PieceSymbols:sub(promote, promote)
	end
	return text
end
-- UCI long-algebraic form of an encoded move, e.g. "e2e4", or "a7a8q" with the
-- promotion piece appended in lower case.
local function ucimove(m)
	local promote = bit32.extract(m, 15, 4)
	local suffix = ""
	if promote > 0 then
		suffix = PieceSymbols:sub(promote, promote):lower()
	end
	return squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6)) .. suffix
end
-- Presumably kept so the otherwise-unused debug helpers stay referenced.
local _utils = {squareName, moveName}
--
-- Bitboards
--
local Bitboard = {}
-- Render the board as an 8x8 grid (rank 8 on top, "x" = set, "-" = clear),
-- followed by a file legend and a line with the popcount and both raw words.
function Bitboard:toString()
	local out = {}
	for rank = 7, 0, -1 do
		-- Ranks 5-8 live in the high word, ranks 1-4 in the low word.
		local word = rank >= 4 and self.h or self.l
		out[#out + 1] = RANKS:sub(rank + 1, rank + 1)
		out[#out + 1] = " "
		local mask = bit32.lshift(1, (rank % 4) * 8)
		for _ = 0, 7 do
			out[#out + 1] = bit32.band(word, mask) ~= 0 and "x " or "- "
			mask = bit32.lshift(mask, 1)
		end
		out[#out + 1] = "\n"
	end
	out[#out + 1] = ' ' .. FILES:gsub('.', '%1 ') .. '\n'
	out[#out + 1] = '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h
	return table.concat(out)
end
-- Construct a bitboard from two 32-bit words: `l` holds squares 0-31 and `h`
-- holds squares 32-63 (see Bitboard:index). Square n is rank*8 + file, so
-- a1 = 0 and h8 = 63.
function Bitboard.from(l ,h )
return setmetatable({l=l, h=h}, Bitboard)
end
-- Shared empty and full boards. Bitboard methods return new boards rather than
-- mutating their receiver, so these can be reused freely.
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
-- Single-rank masks (ranks 3 and 6 are where a pawn lands after a one-square
-- push from its starting rank; see Board:generate).
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
-- Single-file masks: one bit per rank at the file's offset within each byte.
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)
-- The leading underscore marks this table as intentionally unused in this file.
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
-- These masks are filled out below for all files
-- RightMasks[n] / LeftMasks[n] end up covering the n files nearest H / A and are
-- used by Bitboard:move to drop bits that would wrap around a board edge.
local RightMasks = {FileH}
local LeftMasks = {FileA}
-- Number of set bits in a 32-bit value, via the SWAR reduction: sum bits in
-- pairs, then nibbles, then bytes, and gather the byte sums with a multiply.
local function popcnt32(i)
	local pairSums = i - bit32.band(bit32.rshift(i, 1), 0x55555555)
	local nibbleSums = bit32.band(pairSums, 0x33333333) + bit32.band(bit32.rshift(pairSums, 2), 0x33333333)
	local byteSums = bit32.band(nibbleSums + bit32.rshift(nibbleSums, 4), 0x0F0F0F0F)
	return bit32.rshift(byteSums * 0x01010101, 24)
end
-- Shift every square one rank towards rank 8.
function Bitboard:up()
	return self:lshift(8)
end
-- Shift every square one rank towards rank 1.
function Bitboard:down()
	return self:rshift(8)
end
-- Shift one file towards H; squares already on file H fall off instead of wrapping.
function Bitboard:right()
	return self:bandnot(FileH):lshift(1)
end
-- Shift one file towards A; squares already on file A fall off instead of wrapping.
function Bitboard:left()
	return self:bandnot(FileA):rshift(1)
end
-- Shift by x files and y ranks. Negative x moves towards file H and positive x
-- towards file A; the edge files that would wrap are masked off before shifting.
function Bitboard:move(x,y)
	local out = self
	if x < 0 then
		out = out:bandnot(RightMasks[-x]):lshift(-x)
	elseif x > 0 then
		out = out:bandnot(LeftMasks[x]):rshift(x)
	end
	if y < 0 then
		out = out:rshift(-8 * y)
	elseif y > 0 then
		out = out:lshift(8 * y)
	end
	return out
end
-- Number of set squares.
function Bitboard:popcnt()
	return popcnt32(self.l) + popcnt32(self.h)
end
-- Intersection.
function Bitboard:band(other)
	local lo = bit32.band(self.l, other.l)
	local hi = bit32.band(self.h, other.h)
	return Bitboard.from(lo, hi)
end
-- Squares of self that are not in other.
function Bitboard:bandnot(other)
	local lo = bit32.band(self.l, bit32.bnot(other.l))
	local hi = bit32.band(self.h, bit32.bnot(other.h))
	return Bitboard.from(lo, hi)
end
-- True when self and other share no square (cheaper than band():empty()).
function Bitboard:bandempty(other)
	if bit32.band(self.l, other.l) ~= 0 then
		return false
	end
	return bit32.band(self.h, other.h) == 0
end
-- Union.
function Bitboard:bor(other)
	local lo = bit32.bor(self.l, other.l)
	local hi = bit32.bor(self.h, other.h)
	return Bitboard.from(lo, hi)
end
-- Symmetric difference.
function Bitboard:bxor(other)
	local lo = bit32.bxor(self.l, other.l)
	local hi = bit32.bxor(self.h, other.h)
	return Bitboard.from(lo, hi)
end
-- Complement over all 64 squares.
function Bitboard:inverse()
	local lo = bit32.bxor(self.l, 0xFFFFFFFF)
	local hi = bit32.bxor(self.h, 0xFFFFFFFF)
	return Bitboard.from(lo, hi)
end
-- True when no square is set.
function Bitboard:empty()
	return self.l == 0 and self.h == 0
end
if bit32.countrz then
	-- bit32.countrz is available: use it directly.
	-- ctz: index of the lowest set square, or 64 for an empty board.
	function Bitboard:ctz()
		local low = bit32.countrz(self.l)
		if low ~= 32 then
			return low
		end
		return bit32.countrz(self.h) + 32
	end
	-- ctzafter: index of the lowest set square strictly above `start`, or 64.
	function Bitboard:ctzafter(start)
		local above = Bitboard.full:lshift(start + 1)
		return self:band(above):ctz()
	end
else
	-- Portable fallback: scan individual bits with bit32.extract.
	local function ctz32(v)
		if v == 0 then
			return 32
		end
		local bit = 0
		while bit32.extract(v, bit) == 0 do
			bit = bit + 1
		end
		return bit
	end
	function Bitboard:ctz()
		local low = ctz32(self.l)
		if low ~= 32 then
			return low
		end
		return ctz32(self.h) + 32
	end
	function Bitboard:ctzafter(start)
		local first = start + 1
		for i = first, 31 do
			if bit32.extract(self.l, i) == 1 then
				return i
			end
		end
		for i = math.max(32, first), 63 do
			if bit32.extract(self.h, i - 32) == 1 then
				return i
			end
		end
		return 64
	end
end
-- Shift all 64 bits towards higher square indices (towards rank 8 / file H);
-- bits pushed past square 63 are dropped. amt must be non-negative.
function Bitboard:lshift(amt)
	assert(amt >= 0)
	if amt == 0 then return self end
	if amt > 31 then
		-- Only the low word can survive, landing in the high word
		-- (bit32.lshift yields 0 for displacements of 32 or more).
		return Bitboard.from(0, bit32.lshift(self.l, amt-32))
	end
	local l = bit32.lshift(self.l, amt)
	-- Carry the top `amt` bits of the low word into the bottom of the high word.
	local h = bit32.bor(
		bit32.lshift(self.h, amt),
		bit32.extract(self.l, 32-amt, amt)
	)
	return Bitboard.from(l, h)
end
-- Shift all 64 bits towards lower square indices (towards rank 1 / file A);
-- bits pushed below square 0 are dropped. amt must be non-negative.
function Bitboard:rshift(amt)
	assert(amt >= 0)
	if amt == 0 then return self end
	if amt > 31 then
		-- Mirror of lshift: only the high word survives, landing in the low word.
		-- Without this branch, bit32.extract below would be asked for more than
		-- 32 bits whenever amt > 32 and raise an error.
		return Bitboard.from(bit32.rshift(self.h, amt-32), 0)
	end
	local h = bit32.rshift(self.h, amt)
	-- Carry the bottom `amt` bits of the high word into the top of the low word.
	local l = bit32.bor(
		bit32.rshift(self.l, amt),
		bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
	)
	return Bitboard.from(l, h)
end
-- Read square i (0..63) as 0 or 1.
function Bitboard:index(i)
	if i <= 31 then
		return bit32.extract(self.l, i)
	end
	return bit32.extract(self.h, i - 32)
end
-- Return a copy with square i forced to v (0 or 1).
function Bitboard:set(i , v)
	local lo, hi = self.l, self.h
	if i <= 31 then
		lo = bit32.replace(lo, v, i)
	else
		hi = bit32.replace(hi, v, i - 32)
	end
	return Bitboard.from(lo, hi)
end
-- Keep only square i of this board.
function Bitboard:isolate(i)
	local single = Bitboard.some(i)
	return self:band(single)
end
-- A board with exactly square idx set.
function Bitboard.some(idx )
	return Bitboard.zero:set(idx, 1)
end
-- Instances resolve methods on Bitboard and print through Bitboard.toString.
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString
-- Widen the edge masks: RightMasks[n] covers the n files nearest H and
-- LeftMasks[n] the n files nearest A (entry 1 is seeded with FileH / FileA).
-- This must run after __index is installed because it calls methods.
for width = 2, 8 do
	local narrowerRight = RightMasks[width - 1]
	local narrowerLeft = LeftMasks[width - 1]
	RightMasks[width] = narrowerRight:rshift(1):bor(FileH)
	LeftMasks[width] = narrowerLeft:lshift(1):bor(FileA)
end
--
-- Board
--
local Board = {}
-- Create an empty position: twelve empty piece bitboards in slots 1..12
-- (PieceSymbols order), the cached aggregate boards, and game-state fields.
function Board.new()
	local board = setmetatable(table.create(12, Bitboard.zero), Board)
	board.ocupied, board.unocupied = Bitboard.zero, Bitboard.full
	board.white, board.black = Bitboard.zero, Bitboard.zero
	board.ep, board.castle = Bitboard.zero, Bitboard.zero
	board.toMove = 1
	board.hm, board.moves, board.material = 0, 0, 0
	return board
end
-- Build a Board from a FEN string: piece placement, side to move, castling
-- rights, en-passant square, halfmove clock, fullmove number.
-- Raises an error when the fields after the piece placement do not parse.
function Board.fromFen(fen )
	local b = Board.new()
	local i = 0
	local rank = 7
	local file = 0
	-- Piece placement: consume characters until one is not '/', a digit, or a
	-- piece letter (normally the space before the side-to-move field).
	while true do
		i = i + 1
		local p = fen:sub(i,i)
		if p == '/' then
			rank = rank - 1
			file = 0
		elseif tonumber(p) ~= nil then
			-- A digit skips that many empty files.
			file = file + tonumber(p)
		else
			-- Plain (non-pattern) lookup so characters like '.' or '%' are rejected
			-- rather than matched as patterns, and stop at the end of the string,
			-- where PieceSymbols:find("") would otherwise return 1.
			local pidx = nil
			if p ~= "" then
				pidx = PieceSymbols:find(p, 1, true)
			end
			if pidx == nil then break end
			b[pidx] = b[pidx]:set(rank*8+file, 1)
			file = file + 1
		end
	end
	local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
	if move == nil then
		error("malformed FEN fields: " .. fen:sub(i), 2)
	end
	b.toMove = move == 'w' and 1 or 2
	if ep ~= "-" then
		b.ep = Bitboard.some(square(ep))
	end
	-- Castle rights are stored as bits on the rooks' home squares.
	if castle ~= "-" then
		local oo = Bitboard.zero
		if castle:find("K") then
			oo = oo:set(7, 1)
		end
		if castle:find("Q") then
			oo = oo:set(0, 1)
		end
		if castle:find("k") then
			oo = oo:set(63, 1)
		end
		if castle:find("q") then
			oo = oo:set(56, 1)
		end
		b.castle = oo
	end
	-- Store the clocks as numbers, matching what applyMove writes.
	b.hm = tonumber(hm) or 0
	b.moves = tonumber(m) or 0
	b:updateCache()
	return b
end
-- Piece code (1..12, PieceSymbols order) on square idx, or 0 if it is empty.
function Board:index(idx )
	-- White pieces occupy the odd slots and Black pieces the even ones, so
	-- only one colour's boards need to be scanned.
	local first = self.white:index(idx) == 1 and 1 or 2
	for p = first, 12, 2 do
		if self[p]:index(idx) == 1 then
			return p
		end
	end
	return 0
end
-- Rebuild the derived boards (white, black, ocupied, unocupied) and the
-- material balance (White minus Black, in hundredths of a pawn; kings are not
-- counted) from the twelve piece bitboards.
function Board:updateCache()
	-- Recompute from empty rather than OR-ing into the previous values, so a
	-- repeated call after editing a piece board cannot keep stale squares.
	local white, black = Bitboard.zero, Bitboard.zero
	for i = 1, 11, 2 do
		white = white:bor(self[i])
		black = black:bor(self[i + 1])
	end
	self.white = white
	self.black = black
	self.ocupied = black:bor(white)
	self.unocupied = self.ocupied:inverse()
	self.material =
		100*self[1]:popcnt() - 100*self[2]:popcnt() +
		500*self[3]:popcnt() - 500*self[4]:popcnt() +
		300*self[5]:popcnt() - 300*self[6]:popcnt() +
		300*self[7]:popcnt() - 300*self[8]:popcnt() +
		900*self[9]:popcnt() - 900*self[10]:popcnt()
end
-- Serialize the position as a FEN string.
function Board:fen()
	local parts = {}
	local gap = 0
	-- Emit the pending run of empty squares, if any, as a digit.
	local function flushGap()
		if gap > 0 then
			parts[#parts + 1] = '' .. gap
			gap = 0
		end
	end
	for rank = 7, 0, -1 do
		if rank < 7 then
			flushGap()
			parts[#parts + 1] = '/'
		end
		for file = 0, 7 do
			local p = self:index(rank * 8 + file)
			if p == 0 then
				gap = gap + 1
			else
				flushGap()
				parts[#parts + 1] = PieceSymbols:sub(p, p)
			end
		end
	end
	flushGap()
	parts[#parts + 1] = self.toMove == 1 and ' w ' or ' b '
	if self.castle:empty() then
		parts[#parts + 1] = '-'
	else
		-- Castle rights live on the rook home squares h1, a1, h8, a8.
		local flags = {{7, 'K'}, {0, 'Q'}, {63, 'k'}, {56, 'q'}}
		for _, flag in ipairs(flags) do
			if self.castle:index(flag[1]) == 1 then
				parts[#parts + 1] = flag[2]
			end
		end
	end
	parts[#parts + 1] = ' '
	if self.ep:empty() then
		parts[#parts + 1] = '-'
	else
		parts[#parts + 1] = squareName(self.ep:ctz())
	end
	parts[#parts + 1] = ' ' .. self.hm
	parts[#parts + 1] = ' ' .. self.moves
	return table.concat(parts)
end
-- Pseudo-legal destinations of the piece on idx (own-king safety not checked).
function Board:pmoves(idx)
	return self:generate(idx)
end
-- The subset of those destinations that is currently occupied.
function Board:pcaptures(idx)
	local targets = self:generate(idx)
	return targets:band(self.ocupied)
end
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
-- Pseudo-legal destination squares (as a Bitboard) for the piece on idx.
-- Own-king safety is not considered; Board:moveList filters that afterwards.
-- The piece kind is (piece - 1) >> 1: 0 pawn, 1 rook, 2 knight, 3 bishop,
-- 4 queen, 5 king; odd piece codes are White and even codes are Black.
function Board:generate(idx)
	local piece = self:index(idx)
	if piece == 0 then
		return Bitboard.zero
	end
	local kind = bit32.rshift(piece - 1, 1)
	local origin = Bitboard.some(idx)
	-- Enemy-occupied squares may be captured; own pieces block.
	local enemies = piece % 2 == 1 and self.black or self.white

	if kind == 0 then
		-- Pawn: +1 rank for White (piece 1), -1 rank for Black (piece 2).
		local dir = 3 - 2 * piece
		local pushes = origin:move(0, dir):band(self.unocupied)
		local doubleRank = piece == 1 and Rank3 or Rank6
		pushes = pushes:bor(pushes:band(doubleRank):move(0, dir):band(self.unocupied))
		local ahead = origin:move(0, dir)
		local diagonals = ahead:right():bor(ahead:left())
		if not diagonals:bandempty(self.ep) then
			pushes = pushes:bor(self.ep)
		end
		return pushes:bor(diagonals:band(enemies))
	end

	local out = Bitboard.zero
	if kind == 5 then
		-- King: every neighbouring square. The (0,0) offset is its own square,
		-- which is occupied by a friendly piece and therefore never added.
		for dx = -1, 1 do
			for dy = -1, 1 do
				local step = origin:move(dx, dy)
				if self.ocupied:bandempty(step) or not enemies:bandempty(step) then
					out = out:bor(step)
				end
			end
		end
	elseif kind == 2 then
		-- Knight: fixed jumps that ignore intervening pieces.
		for _, jump in ipairs(KNIGHT_MOVES) do
			local step = origin:move(jump[1], jump[2])
			if self.ocupied:bandempty(step) or not enemies:bandempty(step) then
				out = out:bor(step)
			end
		end
	else
		-- Sliders: walk each ray until leaving the board or hitting a piece
		-- (an enemy blocker is included as a capture, a friendly one is not).
		local rays = QUEEN_SLIDES
		if kind == 1 then
			rays = ROOK_SLIDES
		elseif kind == 3 then
			rays = BISHOP_SLIDES
		end
		for _, dir in ipairs(rays) do
			local ray = origin
			for _ = 1, 7 do
				ray = ray:move(dir[1], dir[2])
				if ray:empty() then break end
				if self.ocupied:bandempty(ray) then
					out = out:bor(ray)
				else
					if not enemies:bandempty(ray) then
						out = out:bor(ray)
					end
					break
				end
			end
		end
	end
	return out
end
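-- Move encoding (an integer built with bit32.replace, read with bit32.extract):
-- 0-5   - To Square
-- 6-11  - From Square
-- 12    - Is Check (not set by moveList)
-- 13    - Is En Passant (not set by moveList)
-- 14    - Is Castle
-- 15-18 - Promotion Piece
-- 20-23 - Moved Piece
-- 25-28 - Captured Piece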
-- Render the position as text, rank 8 on top, with a file legend, the side to
-- move and the material balance in pawns. If `mark` (a Bitboard) is given,
-- each marked square is bracketed as "(x)".
function Board:toString(mark)
	local out = {}
	for rank = 8, 1, -1 do
		out[#out + 1] = RANKS:sub(rank, rank) .. " "
		for file = 1, 8 do
			local sq = 8 * rank + file - 9
			local piece = self:index(sq)
			if piece == 0 then
				out[#out + 1] = '-'
			else
				out[#out + 1] = UnicodePieces[piece]
			end
			local sep = ' '
			if mark ~= nil then
				if mark:index(sq) ~= 0 then
					sep = ')'
				elseif sq < 63 and file < 8 and mark:index(sq + 1) ~= 0 then
					sep = '('
				end
			end
			out[#out + 1] = sep
		end
		out[#out + 1] = "\n"
	end
	out[#out + 1] = ' ' .. FILES:gsub('.', '%1 ') .. '\n'
	out[#out + 1] = (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n"
	return table.concat(out)
end
-- Legal moves for the side to move, as a list of encoded move integers
-- (to: bits 0-5, from: 6-11, castle flag: 14, promotion: 15-18, moved piece:
-- 20-23, captured piece: 25-28). Each candidate is applied and kept only if
-- it does not leave the mover's own king attacked.
function Board:moveList()
	local tm = self.toMove == 1 and self.white or self.black
	local castle_rank = self.toMove == 1 and Rank1 or Rank8
	local out = {}
	local function emit(id)
		if not self:applyMove(id):illegalyChecked() then
			table.insert(out, id)
		end
	end
	-- Castling: rights are bits on the rook home squares; an own piece must
	-- still stand there.
	local cr = tm:band(self.castle):band(castle_rank)
	if not cr:empty() then
		local p = self.toMove == 1 and 11 or 12
		local tcolor = self.toMove == 1 and self.black or self.white
		local kidx = self[p]:ctz()
		local castle = bit32.replace(0, p, 20, 4)
		castle = bit32.replace(castle, kidx, 6, 6)
		castle = bit32.replace(castle, 1, 14)
		-- Queenside: b, c and d squares empty; king, c and d squares not attacked.
		local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
		local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
		if
			not cr:bandempty(FileA) and
			mustbeemptyl:bandempty(self.ocupied) and
			not self:isSquareThreatened(cantbethreatened, tcolor)
		then
			emit(bit32.replace(castle, kidx - 2, 0, 6))
		end
		-- Kingside: f and g squares empty and, together with the king, not attacked.
		local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
		if
			not cr:bandempty(FileH) and
			mustbeemptyr:bandempty(self.ocupied) and
			not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
		then
			emit(bit32.replace(castle, kidx + 2, 0, 6))
		end
	end
	-- Ordinary moves of each own piece, lowest square first. ctz/ctzafter return
	-- 64 once no bits remain; testing before the body (rather than a
	-- repeat/until) also avoids calling index(64) when the side has no pieces.
	local sq = tm:ctz()
	while sq ~= 64 do
		local p = self:index(sq)
		local moves = self:pmoves(sq)
		while not moves:empty() do
			local m = moves:ctz()
			moves = moves:set(m, 0)
			local id = bit32.replace(m, sq, 6, 6)
			id = bit32.replace(id, p, 20, 4)
			local mbb = Bitboard.some(m)
			if not self.ocupied:bandempty(mbb) then
				id = bit32.replace(id, self:index(m), 25, 4)
			end
			-- Check if pawn needs to be promoted: one move per promotion piece (R, N, B, Q).
			if p == 1 and m >= 8*7 then
				for i=3,9,2 do
					emit(bit32.replace(id, i, 15, 4))
				end
			elseif p == 2 and m < 8 then
				for i=4,10,2 do
					emit(bit32.replace(id, i, 15, 4))
				end
			else
				emit(id)
			end
		end
		sq = tm:ctzafter(sq)
	end
	return out
end
-- True when the side that just moved left its own king attacked. After
-- applyMove, toMove already names the opponent, so with White to move the
-- black king is tested against White's pieces, and vice versa.
function Board:illegalyChecked()
	local kingBoard, attackers
	if self.toMove == 1 then
		kingBoard = self[PieceSymbols:find("k")]
		attackers = self.white
	else
		kingBoard = self[PieceSymbols:find("K")]
		attackers = self.black
	end
	return self:isSquareThreatened(kingBoard, attackers)
end
-- True when any piece on a square of `color` (a Bitboard of attackers) has a
-- pseudo-legal destination intersecting `target`.
-- NOTE(review): pmoves also contains pawn pushes, which are not attacks. For an
-- occupied target (a king) this cannot matter, since pushes only reach empty
-- squares; for the empty castling squares checked in moveList it may
-- over-report. Confirm before relying on it elsewhere.
function Board:isSquareThreatened(target , color )
	-- ctz/ctzafter return 64 when no bits remain. Testing before the body also
	-- handles an empty `color`, which previously reached index(64) and raised.
	local sq = color:ctz()
	while sq ~= 64 do
		if not self:pmoves(sq):bandempty(target) then
			return true
		end
		sq = color:ctzafter(sq)
	end
	return false
end
-- Perft: number of leaf positions reachable in exactly `depth` legal plies.
function Board:perft(depth )
	if depth == 0 then
		return 1
	end
	local moves = self:moveList()
	if depth == 1 then
		-- Bulk counting: at the last ply each legal move is one leaf.
		return #moves
	end
	-- Positions with no legal moves contribute 0, i.e. only leaves at the
	-- target depth are counted.
	local total = 0
	for _, move in ipairs(moves) do
		total = total + self:applyMove(move):perft(depth - 1)
	end
	return total
end
-- Return a new Board with the encoded `move` applied (to: bits 0-5, from:
-- bits 6-11, castle flag: bit 14, promotion piece: bits 15-18). `self` is
-- not modified.
function Board:applyMove(move )
	local out = Board.new()
	-- Copy the twelve piece bitboards; derived boards are rebuilt by updateCache.
	table.move(self, 1, 12, 1, out)
	local from = bit32.extract(move, 6, 6)
	local to = bit32.extract(move, 0, 6)
	local promote = bit32.extract(move, 15, 4)
	local piece = self:index(from)
	local captured = self:index(to)
	local tom = Bitboard.some(to)
	local isCastle = bit32.extract(move, 14)
	-- Fullmove number: advances after Black moves, otherwise carried over
	-- (previously it fell back to 0 after every White move).
	if piece % 2 == 0 then
		out.moves = self.moves + 1
	else
		out.moves = self.moves
	end
	-- Halfmove clock: reset by any capture or pawn move
	-- (previously only a capture of a white pawn reset it).
	if captured ~= 0 or piece < 3 then
		out.hm = 0
	else
		out.hm = self.hm + 1
	end
	out.castle = self.castle
	out.toMove = self.toMove == 1 and 2 or 1
	if isCastle == 1 then
		local rank = piece == 11 and Rank1 or Rank8
		local rook = 3 + (piece - 11)
		local kingside = from < to
		-- Relocate only the castling rook on its back-rank corner; any other rook
		-- on the same file must stay put.
		out[rook] = out[rook]:bandnot((kingside and FileH or FileA):band(rank))
		out[rook] = out[rook]:bor((kingside and FileF or FileD):band(rank))
		out[piece] = (kingside and FileG or FileC):band(rank)
		out.castle = out.castle:bandnot(rank)
		out:updateCache()
		return out
	end
	if piece < 3 then
		local dist = math.abs(to - from)
		-- Pawn moved two squares: the skipped square becomes the en-passant target.
		if dist == 16 then
			out.ep = Bitboard.some((from + to) / 2)
		end
		-- En-passant capture: remove the pawn that made the double step.
		if not tom:bandempty(self.ep) then
			if piece == 1 then
				out[2] = out[2]:bandnot(self.ep:down())
			end
			if piece == 2 then
				out[1] = out[1]:bandnot(self.ep:up())
			end
		end
	end
	-- A rook leaving its square gives up any castle right stored there.
	if piece == 3 or piece == 4 then
		out.castle = out.castle:set(from, 0)
	end
	-- A king move gives up both castle rights on its back rank.
	if piece > 10 then
		local rank = piece == 11 and Rank1 or Rank8
		out.castle = out.castle:bandnot(rank)
	end
	out[piece] = out[piece]:set(from, 0)
	if promote == 0 then
		out[piece] = out[piece]:set(to, 1)
	else
		out[promote] = out[promote]:set(to, 1)
	end
	if captured ~= 0 then
		out[captured] = out[captured]:set(to, 0)
		-- A rook captured on its home square takes its castle right with it.
		out.castle = out.castle:set(to, 0)
	end
	out:updateCache()
	return out
end
-- Board instances print via Board.toString and resolve methods on Board.
Board.__tostring = Board.toString
Board.__index = Board
--
-- Main
--
local failures = 0
-- Check one perft reference: parse `fen`, require it to round-trip through
-- Board:fen, then compare perft(ply) with `target`. Mismatches increment
-- `failures` and print a per-move breakdown ("<uci move>: <node count>").
local function test(fen, ply, target)
	local b = Board.fromFen(fen)
	local roundTrip = b:fen()
	if roundTrip ~= fen then
		print("FEN MISMATCH", fen, roundTrip)
		failures = failures + 1
		return
	end
	local found = b:perft(ply)
	if found ~= target then
		print(fen, "Found", found, "target", target)
		failures = failures + 1
		-- ipairs keeps the breakdown in move-generation order; pairs order is unspecified.
		for _, v in ipairs(b:moveList()) do
			print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
		end
		--error("Test Failure")
	else
		print("OK", found, fen)
	end
end
-- Perft reference positions and expected leaf counts, taken from
-- https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper
-- Each entry is {fen, ply, expected leaf count}.
local testCases = {
	{StartingFen, 2, 400},
	{"r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48},
	{"8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191},
	{"r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264},
	{"rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44},
	{"r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46},
}
-- Benchmark body: run every perft case once, in order.
local function chess()
	for _, case in ipairs(testCases) do
		local fen, ply, target = case[1], case[2], case[3]
		test(fen, ply, target)
	end
end
bench.runCode(chess, "chess")