#!/usr/bin/python3
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details

# NOTE: This script is experimental. It uses linear regression to construct a model for predicting native
# code size from bytecode. Some initial work has been done to analyze a large corpus of Luau scripts, and while for
# most functions the model predicts the native code size quite well (+/-25%), there are many cases where the predicted
# size is off by as much as 13x. Notably, the prediction is generally better for smaller functions and worse for
# larger functions. Therefore, in its current form this analysis is probably not suitable for use as a basis for
# compilation heuristics. A nonlinear model may produce better results. The script here exists as a foundation for
# further exploration.
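
# Example usage (a sketch; the glob and output file names are illustrative, and the stats
# JSON files are assumed to have been produced beforehand by the Compile.cpp CLI):
#
#   python codesizeprediction.py "stats/**/*.json" model.txt \
#       --nativesizefig=nativesize.png --predictionerrorfig=prederror.png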


import json
import glob
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import argparse


def readStats(statsFileGlob):
    '''Reads per-function compilation stats from JSON files matching the supplied glob.
    The files should be generated by the Compile.cpp CLI.'''

    statsFiles = glob.glob(statsFileGlob, recursive=True)

    print("Reading %s files." % len(statsFiles))

    df_dict = {
        "statsFile": [],
        "script": [],
        "name": [],
        "line": [],
        "bcodeCount": [],
        "irCount": [],
        "asmCount": [],
        "bytecodeSummary": []
    }

    for statsFile in statsFiles:
        stats = json.loads(Path(statsFile).read_text())
        for script, filestats in stats.items():
            for funstats in filestats["lowerStats"]["functions"]:
                df_dict["statsFile"].append(statsFile)
                df_dict["script"].append(script)
                df_dict["name"].append(funstats["name"])
                df_dict["line"].append(funstats["line"])
                df_dict["bcodeCount"].append(funstats["bcodeCount"])
                df_dict["irCount"].append(funstats["irCount"])
                df_dict["asmCount"].append(funstats["asmCount"])
                df_dict["bytecodeSummary"].append(
                    tuple(funstats["bytecodeSummary"][0]))

    return pd.DataFrame.from_dict(df_dict)
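
# For reference, each stats JSON file is expected to look roughly like this (shape inferred
# from the keys read above; the values are illustrative):
#
#   {
#       "path/to/script.luau": {
#           "lowerStats": {
#               "functions": [
#                   {"name": "f", "line": 10, "bcodeCount": 12, "irCount": 80,
#                    "asmCount": 150, "bytecodeSummary": [[0, 3, 1, ...]]}
#               ]
#           }
#       }
#   }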


def addFunctionCount(df):
    '''Adds a 'functionCount' column giving the number of distinct functions
    (after dropping asmCount/bytecodeSummary duplicates) that share each bytecode summary.'''
    df2 = df.drop_duplicates(subset=['asmCount', 'bytecodeSummary'], ignore_index=True).groupby(
        ['bytecodeSummary']).size().reset_index(name='functionCount')
    return df.merge(df2, on='bytecodeSummary', how='left')

# def deduplicateDf(df):
#    return df.drop_duplicates(subset=['bcodeCount', 'asmCount', 'bytecodeSummary'], ignore_index=True)


def randomizeDf(df):
    '''Returns the rows of df in random order.'''
    return df.sample(frac=1)


def splitSeq(seq):
    '''Splits a sequence into two halves (training and validation).'''
    n = len(seq) // 2
    return (seq[:n], seq[n:])


def trainAsmSizePredictor(df):
    '''Fits a linear model (non-negative coefficients, no intercept) mapping each function's
    bytecode opcode summary to its native instruction count. The first half of the rows is
    used for training, the second half for validation.'''
    XTrain, XValidate = splitSeq(
        np.array([list(seq) for seq in df.bytecodeSummary]))
    YTrain, YValidate = splitSeq(np.array(df.asmCount))

    reg = LinearRegression(
        positive=True, fit_intercept=False).fit(XTrain, YTrain)
    YPredict1 = reg.predict(XTrain)
    YPredict2 = reg.predict(XValidate)

    trainRmse = np.sqrt(np.mean((np.array(YPredict1) - np.array(YTrain))**2))
    predictRmse = np.sqrt(
        np.mean((np.array(YPredict2) - np.array(YValidate))**2))

    print(f"Score: {reg.score(XTrain, YTrain)}")
    print(f"Training RMSE: {trainRmse}")
    print(f"Prediction RMSE: {predictRmse}")
    print(f"Model Intercept: {reg.intercept_}")
    print(f"Model Coefficients:\n{reg.coef_}")

    df.loc[:, 'asmCountPredicted'] = np.concatenate(
        (YPredict1, YPredict2)).round().astype(int)
    df['usedForTraining'] = np.concatenate(
        (np.repeat(True, YPredict1.size), np.repeat(False, YPredict2.size)))
    df['diff'] = df['asmCountPredicted'] - df['asmCount']
    df['diffPerc'] = (100 * df['diff']) / df['asmCount']
    df.loc[(df["diffPerc"] == np.inf), 'diffPerc'] = 0.0
    df['diffPerc'] = df['diffPerc'].round()

    return (reg, df)
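

# A minimal sketch (not used by the rest of the script) of how a trained model could be
# applied to a single function: with fit_intercept=False, the prediction is simply the dot
# product of the learned per-opcode coefficients with the function's opcode counts.
def predictAsmCount(reg, bytecodeSummary):
    '''Predicts the native instruction count of one function from its bytecode
    opcode summary, using a model returned by trainAsmSizePredictor.'''
    return int(round(reg.predict(np.array([list(bytecodeSummary)]))[0]))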


def saveModel(reg, file):
    with open(file, "w") as f:
        f.write(f"Intercept: {reg.intercept_}\n")
        f.write(f"Coefficients: \n{reg.coef_}\n")


def bcodeVsAsmPlot(df, plotFile=None, minBcodeCount=None, maxBcodeCount=None):
    '''Scatter plot of native instruction count against bytecode instruction count,
    optionally restricted to [minBcodeCount, maxBcodeCount] and saved to plotFile.'''
    if minBcodeCount is None:
        minBcodeCount = df.bcodeCount.min()
    if maxBcodeCount is None:
        maxBcodeCount = df.bcodeCount.max()

    subDf = df[(df.bcodeCount <= maxBcodeCount) &
               (df.bcodeCount >= minBcodeCount)]

    plt.scatter(subDf.bcodeCount, subDf.asmCount)
    plt.title("ASM variation by Bytecode")
    plt.xlabel("Bytecode Instruction Count")
    plt.ylabel("ASM Instruction Count")

    if plotFile is not None:
        plt.savefig(plotFile)

    return plt


def predictionErrorPlot(df, plotFile=None, minPerc=None, maxPerc=None, bins=200):
    '''Histogram of the prediction error (in percent) over the validation rows,
    optionally restricted to [minPerc, maxPerc] and saved to plotFile.'''
    if minPerc is None:
        minPerc = df['diffPerc'].min()
    if maxPerc is None:
        maxPerc = df['diffPerc'].max()

    plotDf = df[~df["usedForTraining"] & (
        df["diffPerc"] >= minPerc) & (df["diffPerc"] <= maxPerc)]

    plt.hist(plotDf["diffPerc"], bins=bins)
    plt.title("Prediction Error Distribution")
    plt.xlabel("Prediction Error %")
    plt.ylabel("Function Count")

    if plotFile is not None:
        plt.savefig(plotFile)

    return plt


def parseArgs():
    parser = argparse.ArgumentParser(
        prog='codesizeprediction.py',
        description='Constructs a linear regression model to predict native instruction count from bytecode opcode distribution')
    parser.add_argument("fileglob",
                        help="glob pattern for stats files to be used for training")
    parser.add_argument("modelfile",
                        help="text file to save model details")
    parser.add_argument("--nativesizefig",
                        help="path for saving the plot showing the variation of native code size with bytecode")
    parser.add_argument("--predictionerrorfig",
                        help="path for saving the plot showing the distribution of prediction error")
    return parser.parse_args()


if __name__ == "__main__":
    args = parseArgs()

    df0 = readStats(args.fileglob)
    df1 = addFunctionCount(df0)
    df2 = randomizeDf(df1)

    # Scatter plot restricted to functions with at most 100 bytecode instructions
    plt = bcodeVsAsmPlot(df2, args.nativesizefig, 0, 100)
    plt.show()

    (reg, df4) = trainAsmSizePredictor(df2)
    saveModel(reg, args.modelfile)

    # Validation-error histogram restricted to predictions within +/-200% of the actual size
    plt = predictionErrorPlot(df4, args.predictionerrorfig, -200, 200)
    plt.show()