From c26d820902ac66740bf2054e0822b7024a67c4cf Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Fri, 8 Dec 2023 23:50:16 +0200 Subject: [PATCH 1/2] Sync to upstream/release/606 (#1127) New Solver * Improvements to data flow analysis Native Code Generation * Block limit is now per-function instead of per-module Co-authored-by: Alexander McCord Co-authored-by: Andy Friesen Co-authored-by: Aviral Goel Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/ConstraintGenerator.h | 2 +- Analysis/include/Luau/DataFlowGraph.h | 31 +++- Analysis/include/Luau/Def.h | 1 + Analysis/src/ConstraintGenerator.cpp | 44 +++-- Analysis/src/DataFlowGraph.cpp | 182 +++++++++++++------- Analysis/src/Def.cpp | 19 +- Analysis/src/Frontend.cpp | 108 +++--------- CLI/Bytecode.cpp | 2 +- CodeGen/src/CodeGen.cpp | 27 ++- CodeGen/src/CodeGenLower.h | 13 +- CodeGen/src/EmitCommon.h | 3 - Common/include/Luau/ExperimentalFlags.h | 2 +- tests/Conformance.test.cpp | 59 +++++++ tests/DataFlowGraph.test.cpp | 148 ++++++++++++++++ tests/RequireByString.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 1 - tests/TypeInfer.refinements.test.cpp | 17 +- tests/TypeInfer.typestates.test.cpp | 84 +++++++++ tools/faillist.txt | 15 ++ 19 files changed, 546 insertions(+), 214 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index a3e1092f..0808e84c 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -150,7 +150,7 @@ private: */ ScopePtr childScope(AstNode* node, const ScopePtr& parent); - std::optional lookup(Scope* scope, DefId def); + std::optional lookup(Scope* scope, DefId def, bool prototype = true); /** * Adds a new constraint with no dependencies to a given scope. diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 083e5046..1a983490 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -74,8 +74,15 @@ private: struct DfgScope { + enum ScopeType + { + Linear, + Loop, + Function, + }; + DfgScope* parent; - bool isLoopScope; + ScopeType scopeType; using Bindings = DenseHashMap; using Props = DenseHashMap>; @@ -117,7 +124,17 @@ private: std::vector> scopes; - DfgScope* childScope(DfgScope* scope, bool isLoopScope = false); + struct FunctionCapture + { + std::vector captureDefs; + std::vector allVersions; + size_t versionOffset = 0; + }; + + DenseHashMap captures{Symbol{}}; + void resolveCaptures(); + + DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); void join(DfgScope* p, DfgScope* a, DfgScope* b); void joinBindings(DfgScope::Bindings& p, const DfgScope::Bindings& a, const DfgScope::Bindings& b); @@ -167,11 +184,11 @@ private: DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment = false); - void visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment); - void visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment); - void visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); - void visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); - void visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); + DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment); + DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment); + DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); + DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); + DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); void visitType(DfgScope* scope, AstType* t); void visitType(DfgScope* scope, AstTypeReference* r); diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h index e3fec9b6..9627f998 100644 --- a/Analysis/include/Luau/Def.h +++ b/Analysis/include/Luau/Def.h @@ -73,6 +73,7 @@ const T* get(DefId def) } bool containsSubscriptedDefinition(DefId def); +void collectOperands(DefId def, std::vector* operands); struct DefArena { diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index cdfe9a7f..311cf9a3 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -205,7 +205,7 @@ ScopePtr ConstraintGenerator::childScope(AstNode* node, const ScopePtr& parent) return scope; } -std::optional ConstraintGenerator::lookup(Scope* scope, DefId def) +std::optional ConstraintGenerator::lookup(Scope* scope, DefId def, bool prototype) { if (get(def)) return scope->lookup(def); @@ -213,22 +213,24 @@ std::optional ConstraintGenerator::lookup(Scope* scope, DefId def) { if (auto found = scope->lookup(def)) return *found; + else if (!prototype) + return std::nullopt; TypeId res = builtinTypes->neverType; for (DefId operand : phi->operands) { - // `scope->lookup(operand)` may return nothing because it could be a phi node of globals, but one of - // the operand of that global has never been assigned a type, and so it should be an error. - // e.g. - // ``` - // if foo() then - // g = 5 - // end - // -- `g` here is a phi node of the assignment to `g`, or the original revision of `g` before the branch. - // ``` - TypeId ty = scope->lookup(operand).value_or(builtinTypes->errorRecoveryType()); - res = simplifyUnion(builtinTypes, arena, res, ty).result; + // `scope->lookup(operand)` may return nothing because we only bind a type to that operand + // once we've seen that particular `DefId`. In this case, we need to prototype those types + // and use those at a later time. + std::optional ty = lookup(scope, operand, /*prototype*/false); + if (!ty) + { + ty = arena->addType(BlockedType{}); + rootScope->lvalueTypes[operand] = *ty; + } + + res = simplifyUnion(builtinTypes, arena, res, *ty).result; } scope->lvalueTypes[def] = res; @@ -861,7 +863,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f DenseHashSet excludeList{nullptr}; DefId def = dfg->getDef(function->name); - std::optional existingFunctionTy = scope->lookup(def); + std::optional existingFunctionTy = lookup(scope.get(), def); if (AstExprLocal* localName = function->name->as()) { @@ -1724,16 +1726,14 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprGlobal* globa /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any * global that is not already in-scope is definitely an unknown symbol. */ - if (auto ty = lookup(scope.get(), def)) - return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; - else if (auto ty = scope->lookup(global->name)) + if (auto ty = lookup(scope.get(), def, /*prototype=*/false)) { rootScope->lvalueTypes[def] = *ty; return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; } else { - reportError(global->location, UnknownSymbol{global->name.value}); + reportError(global->location, UnknownSymbol{global->name.value, UnknownSymbol::Binding}); return Inference{builtinTypes->errorRecoveryType()}; } } @@ -3110,6 +3110,16 @@ struct GlobalPrepopulator : AstVisitor return true; } + + bool visit(AstType*) override + { + return true; + } + + bool visit(class AstTypePack* node) override + { + return true; + } }; void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index bdefd7f0..b331474e 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -116,7 +116,7 @@ bool DfgScope::canUpdateDefinition(Symbol symbol) const { if (current->bindings.find(symbol)) return true; - else if (current->isLoopScope) + else if (current->scopeType == DfgScope::Loop) return false; } @@ -129,7 +129,7 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const { if (auto props = current->props.find(def)) return true; - else if (current->isLoopScope) + else if (current->scopeType == DfgScope::Loop) return false; } @@ -144,6 +144,7 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull operands; + for (size_t i = capture.versionOffset; i < capture.allVersions.size(); ++i) + collectOperands(capture.allVersions[i], &operands); + + for (DefId captureDef : capture.captureDefs) + { + Phi* phi = const_cast(get(captureDef)); + LUAU_ASSERT(phi); + LUAU_ASSERT(phi->operands.empty()); + phi->operands = operands; + } + } +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope, DfgScope::ScopeType scopeType) +{ + return scopes.emplace_back(new DfgScope{scope, scopeType}).get(); } void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) @@ -227,24 +246,44 @@ void DataFlowGraphBuilder::joinProps(DfgScope::Props& p, const DfgScope::Props& DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) { - if (auto found = scope->lookup(symbol)) - return *found; - else + for (DfgScope* current = scope; current; current = current->parent) { - DefId result = defArena->freshCell(); - if (symbol.local) - scope->bindings[symbol] = result; - else - moduleScope->bindings[symbol] = result; - return result; + if (auto found = current->bindings.find(symbol)) + return NotNull{*found}; + else if (current->scopeType == DfgScope::Function) + { + FunctionCapture& capture = captures[symbol]; + DefId captureDef = defArena->phi({}); + capture.captureDefs.push_back(captureDef); + scope->bindings[symbol] = captureDef; + return NotNull{captureDef}; + } } + + DefId result = defArena->freshCell(); + scope->bindings[symbol] = result; + captures[symbol].allVersions.push_back(result); + return result; } DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string& key) { - if (auto found = scope->lookup(def, key)) - return *found; - else if (auto phi = get(def)) + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto props = current->props.find(def)) + { + if (auto it = props->find(key); it != props->end()) + return NotNull{it->second}; + } + else if (auto phi = get(def); phi && phi->operands.empty()) // Unresolved phi nodes + { + DefId result = defArena->freshCell(); + scope->props[def][key] = result; + return result; + } + } + + if (auto phi = get(def)) { std::vector defs; for (DefId operand : phi->operands) @@ -361,7 +400,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* whileScope = childScope(scope, /*isLoopScope=*/true); + DfgScope* whileScope = childScope(scope, DfgScope::Loop); visitExpr(whileScope, w->condition); visit(whileScope, w->body); @@ -373,7 +412,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* repeatScope = childScope(scope, /*isLoopScope=*/true); + DfgScope* repeatScope = childScope(scope, DfgScope::Loop); visitBlockWithoutChildScope(repeatScope, r->body); visitExpr(repeatScope, r->condition); @@ -429,6 +468,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) DefId def = defArena->freshCell(subscripted); graph.localDefs[local] = def; scope->bindings[local] = def; + captures[local].allVersions.push_back(def); } return ControlFlow::None; @@ -436,7 +476,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) { - DfgScope* forScope = childScope(scope, /*isLoopScope=*/true); + DfgScope* forScope = childScope(scope, DfgScope::Loop); visitExpr(scope, f->from); visitExpr(scope, f->to); @@ -449,6 +489,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) DefId def = defArena->freshCell(); graph.localDefs[f->var] = def; scope->bindings[f->var] = def; + captures[f->var].allVersions.push_back(def); // TODO(controlflow): entry point has a back edge from exit point visit(forScope, f->body); @@ -460,7 +501,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) { - DfgScope* forScope = childScope(scope, /*isLoopScope=*/true); + DfgScope* forScope = childScope(scope, DfgScope::Loop); for (AstLocal* local : f->vars) { @@ -470,6 +511,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) DefId def = defArena->freshCell(); graph.localDefs[local] = def; forScope->bindings[local] = def; + captures[local].allVersions.push_back(def); } // TODO(controlflow): entry point has a back edge from exit point @@ -527,10 +569,21 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - DefId prototype = defArena->freshCell(); - visitLValue(scope, f->name, prototype); + visitLValue(scope, f->name, defArena->freshCell()); visitExpr(scope, f->func); + if (auto local = f->name->as()) + { + // local f + // function f() + // if cond() then + // f() -- should reference only the function version and other future version, and nothing prior + // end + // end + FunctionCapture& capture = captures[local->local]; + capture.versionOffset = capture.allVersions.size() - 1; + } + return ControlFlow::None; } @@ -539,6 +592,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l DefId def = defArena->freshCell(); graph.localDefs[l->name] = def; scope->bindings[l->name] = def; + captures[l->name].allVersions.push_back(def); visitExpr(scope, l->func); return ControlFlow::None; @@ -559,6 +613,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; scope->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); visitType(scope, d->type); @@ -570,6 +625,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; scope->bindings[d->name] = def; + captures[d->name].allVersions.push_back(def); DfgScope* unreachable = childScope(scope); visitGenerics(unreachable, d->generics); @@ -669,14 +725,9 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* gr DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) { - // DfgScope::lookup is intentional here: we want to be able to ice. - if (auto def = scope->lookup(l->local)) - { - const RefinementKey* key = keyArena->leaf(*def); - return {*def, key}; - } - - handle->ice("DFG: AstExprLocal came before its declaration?"); + DefId def = lookup(scope, l->local); + const RefinementKey* key = keyArena->leaf(def); + return {def, key}; } DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) @@ -723,7 +774,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) { - DfgScope* signatureScope = childScope(scope); + DfgScope* signatureScope = childScope(scope, DfgScope::Function); if (AstLocal* self = f->self) { @@ -733,6 +784,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* DefId def = defArena->freshCell(); graph.localDefs[self] = def; signatureScope->bindings[self] = def; + captures[self].allVersions.push_back(def); } for (AstLocal* param : f->args) @@ -743,6 +795,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* DefId def = defArena->freshCell(); graph.localDefs[param] = def; signatureScope->bindings[param] = def; + captures[param].allVersions.push_back(def); } if (f->varargAnnotation) @@ -827,41 +880,46 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* er void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment) { - if (auto l = e->as()) - return visitLValue(scope, l, incomingDef, isCompoundAssignment); - else if (auto g = e->as()) - return visitLValue(scope, g, incomingDef, isCompoundAssignment); - else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); - else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); - else if (auto error = e->as()) - return visitLValue(scope, error, incomingDef); - else - handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + auto go = [&]() { + if (auto l = e->as()) + return visitLValue(scope, l, incomingDef, isCompoundAssignment); + else if (auto g = e->as()) + return visitLValue(scope, g, incomingDef, isCompoundAssignment); + else if (auto i = e->as()) + return visitLValue(scope, i, incomingDef); + else if (auto i = e->as()) + return visitLValue(scope, i, incomingDef); + else if (auto error = e->as()) + return visitLValue(scope, error, incomingDef); + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); + }; + + graph.astDefs[e] = go(); } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment) { // We need to keep the previous def around for a compound assignment. if (isCompoundAssignment) { - if (auto def = scope->lookup(l->local)) - graph.compoundAssignDefs[l] = *def; + DefId def = lookup(scope, l->local); + graph.compoundAssignDefs[l] = def; } // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(l->local)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); - graph.astDefs[l] = updated; scope->bindings[l->local] = updated; + captures[l->local].allVersions.push_back(updated); + return updated; } else - visitExpr(scope, static_cast(l)); + return visitExpr(scope, static_cast(l)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment) { // We need to keep the previous def around for a compound assignment. if (isCompoundAssignment) @@ -874,28 +932,29 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId if (scope->canUpdateDefinition(g->name)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); - graph.astDefs[g] = updated; scope->bindings[g->name] = updated; + captures[g->name].allVersions.push_back(updated); + return updated; } else - visitExpr(scope, static_cast(g)); + return visitExpr(scope, static_cast(g)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef) { DefId parentDef = visitExpr(scope, i->expr).def; if (scope->canUpdateDefinition(parentDef, i->index.value)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); - graph.astDefs[i] = updated; scope->props[parentDef][i->index.value] = updated; + return updated; } else - visitExpr(scope, static_cast(i)); + return visitExpr(scope, static_cast(i)).def; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef) { DefId parentDef = visitExpr(scope, i->expr).def; visitExpr(scope, i->index); @@ -905,20 +964,19 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, Def if (scope->canUpdateDefinition(parentDef, string->value.data)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); - graph.astDefs[i] = updated; scope->props[parentDef][string->value.data] = updated; + return updated; } else - visitExpr(scope, static_cast(i)); + return visitExpr(scope, static_cast(i)).def; } - - graph.astDefs[i] = defArena->freshCell(); + else + return defArena->freshCell(/*subscripted=*/true); } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef) { - DefId def = visitExpr(scope, error).def; - graph.astDefs[error] = def; + return visitExpr(scope, error).def; } void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp index 2b3bbeac..1ecaca22 100644 --- a/Analysis/src/Def.cpp +++ b/Analysis/src/Def.cpp @@ -19,17 +19,13 @@ bool containsSubscriptedDefinition(DefId def) return false; } -DefId DefArena::freshCell(bool subscripted) +void collectOperands(DefId def, std::vector* operands) { - return NotNull{allocator.allocate(Def{Cell{subscripted}})}; -} - -static void collectOperands(DefId def, std::vector& operands) -{ - if (std::find(operands.begin(), operands.end(), def) != operands.end()) + LUAU_ASSERT(operands); + if (std::find(operands->begin(), operands->end(), def) != operands->end()) return; else if (get(def)) - operands.push_back(def); + operands->push_back(def); else if (auto phi = get(def)) { for (const Def* operand : phi->operands) @@ -37,6 +33,11 @@ static void collectOperands(DefId def, std::vector& operands) } } +DefId DefArena::freshCell(bool subscripted) +{ + return NotNull{allocator.allocate(Def{Cell{subscripted}})}; +} + DefId DefArena::phi(DefId a, DefId b) { return phi({a, b}); @@ -46,7 +47,7 @@ DefId DefArena::phi(const std::vector& defs) { std::vector operands; for (DefId operand : defs) - collectOperands(operand, operands); + collectOperands(operand, &operands); // There's no need to allocate a Phi node for a singleton set. if (operands.size() == 1) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 125f2457..59ef9373 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -32,11 +32,9 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) // TODO: Remove with FFlagLuauTypecheckLimitControls LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false) LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false) LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSetModuleName, false) @@ -902,82 +900,41 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) TypeCheckLimits typeCheckLimits; - if (FFlag::LuauTypecheckLimitControls) + if (item.options.moduleTimeLimitSec) + typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (item.options.applyInternalLimitScaling) { - if (item.options.moduleTimeLimitSec) - typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec; + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); else - typeCheckLimits.finishTime = std::nullopt; + typeCheckLimits.instantiationChildLimit = std::nullopt; - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (item.options.applyInternalLimitScaling) - { - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; - } - - typeCheckLimits.cancellationToken = item.options.cancellationToken; + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; } + typeCheckLimits.cancellationToken = item.options.cancellationToken; + if (item.options.forAutocomplete) { - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; - - if (!FFlag::LuauTypecheckLimitControls) - { - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckLimits.finishTime = std::nullopt; - - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; - - typeCheckLimits.cancellationToken = item.options.cancellationToken; - } - // The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, /*recordJsonLog*/ false, typeCheckLimits); double duration = getTimestamp() - timestamp; - if (FFlag::LuauTypecheckLimitControls) - { - moduleForAutocomplete->checkDurationSec = duration; + moduleForAutocomplete->checkDurationSec = duration; - if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) - applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec); - } - else - { - if (moduleForAutocomplete->timeout) - sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; - else if (duration < autocompleteTimeLimit / 2.0) - sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); - } + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec); item.stats.timeCheck += duration; item.stats.filesStrict += 1; @@ -986,29 +943,16 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) return; } - if (!FFlag::LuauTypecheckLimitControls) - { - typeCheckLimits.cancellationToken = item.options.cancellationToken; - } - ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits); - if (FFlag::LuauTypecheckLimitControls) - { - double duration = getTimestamp() - timestamp; + double duration = getTimestamp() - timestamp; - module->checkDurationSec = duration; + module->checkDurationSec = duration; - if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) - applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec); - - item.stats.timeCheck += duration; - } - else - { - item.stats.timeCheck += getTimestamp() - timestamp; - } + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec); + item.stats.timeCheck += duration; item.stats.filesStrict += mode == Mode::Strict; item.stats.filesNonstrict += mode == Mode::Nonstrict; diff --git a/CLI/Bytecode.cpp b/CLI/Bytecode.cpp index 5002ce1d..76faa6fe 100644 --- a/CLI/Bytecode.cpp +++ b/CLI/Bytecode.cpp @@ -124,7 +124,7 @@ static bool analyzeFile(const char* name, const unsigned nestingLimit, std::vect { Luau::BytecodeBuilder bcb; - compileOrThrow(bcb, source.value(), copts()); + compileOrThrow(bcb, *source, copts()); const std::string& bytecode = bcb.getBytecode(); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 3e6f5e8b..84d3f900 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -42,8 +42,18 @@ LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false) LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false) + +// Per-module IR instruction count limit LUAU_FASTINTVARIABLE(CodegenHeuristicsInstructionLimit, 1'048'576) // 1 M -LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 65'536) // 64 K + +// Per-function IR block limit +// Current value is based on some member variables being limited to 16 bits +// Because block check is made before optimization passes and optimization can generate new blocks, limit is lowered 2x +// The limit will probably be adjusted in the future to avoid performance issues with analysis that's more complex than O(n) +LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K + +// Per-function IR instruction limit +// Current value is based on some member variables being limited to 16 bits LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K namespace Luau @@ -104,11 +114,18 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) } template -static std::optional createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto) +static std::optional createNativeFunction(AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount) { IrBuilder ir; ir.buildFunctionIr(proto); + unsigned instCount = unsigned(ir.function.instructions.size()); + + if (totalIrInstCount + instCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value)) + return std::nullopt; + + totalIrInstCount += instCount; + if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr)) return std::nullopt; @@ -291,9 +308,13 @@ CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, Comp std::vector results; results.reserve(protos.size()); + uint32_t totalIrInstCount = 0; + for (Proto* p : protos) - if (std::optional np = createNativeFunction(build, helpers, p)) + { + if (std::optional np = createNativeFunction(build, helpers, p, totalIrInstCount)) results.push_back(*np); + } // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 2ebe4349..3075ac9a 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -253,11 +253,6 @@ inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, const std::ve template inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { - helpers.bytecodeInstructionCount += unsigned(ir.function.instructions.size()); - - if (helpers.bytecodeInstructionCount >= unsigned(FInt::CodegenHeuristicsInstructionLimit.value)) - return false; - killUnusedBlocks(ir.function); unsigned preOptBlockCount = 0; @@ -268,9 +263,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& preOptBlockCount += (block.kind != IrBlockKind::Dead); unsigned blockInstructions = block.finish - block.start; maxBlockInstructions = std::max(maxBlockInstructions, blockInstructions); - }; - - helpers.preOptBlockCount += preOptBlockCount; + } // we update stats before checking the heuristic so that even if we bail out // our stats include information about the limit that was exceeded. @@ -280,9 +273,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& stats->maxBlockInstructions = maxBlockInstructions; } - // we use helpers.blocksPreOpt instead of stats.blocksPreOpt since - // stats can be null across some code paths. - if (helpers.preOptBlockCount >= unsigned(FInt::CodegenHeuristicsBlockLimit.value)) + if (preOptBlockCount >= unsigned(FInt::CodegenHeuristicsBlockLimit.value)) return false; if (maxBlockInstructions >= unsigned(FInt::CodegenHeuristicsBlockInstructionLimit.value)) diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index f0e14103..013ba88f 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -31,9 +31,6 @@ struct ModuleHelpers // A64 Label continueCall; // x0: closure - - unsigned bytecodeInstructionCount = 0; - unsigned preOptBlockCount = 0; }; } // namespace CodeGen diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 56f5632f..7372cc0d 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -14,7 +14,7 @@ inline bool isFlagExperimental(const char* flag) "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative - "LuauUpdatedRequireByStringSemantics", // requires some small fixes to fully implement some proposed changes + "LuauUpdatedRequireByStringSemantics", // requires some small fixes to fully implement some proposed changes // makes sure we always have at least one entry nullptr, }; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 514211e1..a9c5bc37 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAG(LuauCodeGenFixByteLower); LUAU_FASTFLAG(LuauCompileBufferAnnotation); LUAU_FASTFLAG(LuauLoopInterruptFix); LUAU_DYNAMIC_FASTFLAG(LuauStricterUtf8); +LUAU_FASTINT(CodegenHeuristicsInstructionLimit); static lua_CompileOptions defaultOptions() { @@ -2020,6 +2021,64 @@ TEST_CASE("HugeFunction") CHECK(lua_tonumber(L, -1) == 42); } +TEST_CASE("IrInstructionLimit") +{ + if (!codegen || !luau_codegen_supported()) + return; + + ScopedFastInt codegenHeuristicsInstructionLimit{FInt::CodegenHeuristicsInstructionLimit, 50'000}; + + std::string source; + + // Generate a hundred fat functions + for (int fn = 0; fn < 100; fn++) + { + source += "local function fn" + std::to_string(fn) + "(...)\n"; + source += "if ... then\n"; + source += "local p1, p2 = ...\n"; + source += "local _ = {\n"; + + for (int i = 0; i < 100; ++i) + { + source += "p1*0." + std::to_string(i) + ","; + source += "p2+0." + std::to_string(i) + ","; + } + + source += "}\n"; + source += "return _\n"; + source += "end\n"; + source += "end\n"; + } + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + luau_codegen_create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=HugeFunction", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + Luau::CodeGen::CompilationStats nativeStats = {}; + Luau::CodeGen::CodeGenCompilationResult nativeResult = Luau::CodeGen::compile(L, -1, Luau::CodeGen::CodeGen_ColdFunctions, &nativeStats); + + // Limit is not hit immediately, so with some functions compiled it should be a success + CHECK(nativeResult != Luau::CodeGen::CodeGenCompilationResult::CodeGenFailed); + + // We should be able to compile at least one of our functions + CHECK(nativeStats.functionsCompiled > 0); + + // But because of the limit, not all of them (101 because there's an extra global function) + CHECK(nativeStats.functionsCompiled < 101); +} + TEST_CASE("BytecodeDistributionPerFunctionTest") { const char* source = R"( diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index 3488b90b..2ed26d97 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -341,6 +341,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node") const Phi* phi = get(x3); REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); CHECK(phi->operands.at(0) == x1); CHECK(phi->operands.at(1) == x2); } @@ -368,6 +369,7 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node_2") const Phi* phi = get(x3); REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); CHECK(phi->operands.at(0) == x2); CHECK(phi->operands.at(1) == x1); } @@ -408,8 +410,154 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "property_lookup_on_a_phi_node_3") const Phi* phi = get(x3); REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); CHECK(phi->operands.at(0) == x1); CHECK(phi->operands.at(1) == x2); } +TEST_CASE_FIXTURE(DataFlowGraphFixture, "function_captures_are_phi_nodes_of_all_versions") +{ + dfg(R"( + local x = 5 + + function f() + print(x) + x = nil + end + + f() + x = "five" + )"); + + DefId x1 = graph->getDef(query(module)->vars.data[0]); + DefId x2 = getDef(); // print(x) + DefId x3 = getDef(); // x = nil + DefId x4 = getDef(); // x = "five" + + CHECK(x1 != x2); + CHECK(x2 != x3); + CHECK(x3 != x4); + + const Phi* phi = get(x2); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 3); + CHECK(phi->operands.at(0) == x1); + CHECK(phi->operands.at(1) == x3); + CHECK(phi->operands.at(2) == x4); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "function_captures_are_phi_nodes_of_all_versions_properties") +{ + dfg(R"( + local t = {} + t.x = 5 + + function f() + print(t.x) + t.x = nil + end + + f() + t.x = "five" + )"); + + DefId x1 = getDef(); // t.x = 5 + DefId x2 = getDef(); // print(t.x) + DefId x3 = getDef(); // t.x = nil + DefId x4 = getDef(); // t.x = "five" + + CHECK(x1 != x2); + CHECK(x2 != x3); + CHECK(x3 != x4); + + // When a local is referenced within a function, it is not pointer identical. + // Instead, it's a phi node of all possible versions, including just one version. + DefId t1 = graph->getDef(query(module)->vars.data[0]); + DefId t2 = getDef(); // print(t.x) + + const Phi* phi = get(t2); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == t1); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function") +{ + dfg(R"( + local f + function f() + if cond() then + f() + end + end + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // function f() + DefId f3 = getDef(); // f() + + CHECK(f1 != f2); + CHECK(f2 != f3); + + const Phi* phi = get(f3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == f2); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function_has_some_prior_versions") +{ + dfg(R"( + local f + f = 5 + function f() + if cond() then + f() + end + end + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // f = 5 + DefId f3 = getDef(); // function f() + DefId f4 = getDef(); // f() + + CHECK(f1 != f2); + CHECK(f2 != f3); + CHECK(f3 != f4); + + const Phi* phi = get(f4); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 1); + CHECK(phi->operands.at(0) == f3); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "local_f_which_is_prototyped_enclosed_by_function_has_some_future_versions") +{ + dfg(R"( + local f + function f() + if cond() then + f() + end + end + f = 5 + )"); + + DefId f1 = graph->getDef(query(module)->vars.data[0]); + DefId f2 = getDef(); // function f() + DefId f3 = getDef(); // f() + DefId f4 = getDef(); // f = 5 + + CHECK(f1 != f2); + CHECK(f2 != f3); + CHECK(f3 != f4); + + const Phi* phi = get(f3); + REQUIRE(phi); + REQUIRE(phi->operands.size() == 2); + CHECK(phi->operands.at(0) == f2); + CHECK(phi->operands.at(1) == f4); +} + TEST_SUITE_END(); diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index 0e5ad4ff..f04d7fdf 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -71,6 +71,7 @@ public: luauDirAbs += "/luau"; } + if (type == PathType::Relative) return luauDirRel; if (type == PathType::Absolute) @@ -214,7 +215,6 @@ TEST_CASE("PathNormalization") } } - TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireSimpleRelativePath") { ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 3cda9120..bc6477ad 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -736,7 +736,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index e199ab1b..6dab86cc 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1925,13 +1925,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table") // this test is DCR-only as an instance of DCR fixing a bug in the old solver CheckResult result = check(R"( - local idx, val - local function f(a: unknown) if typeof(a) == "table" then for i, v in a do - idx = i - val = v + return i, v end end end @@ -1939,17 +1936,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Bug: We do not simplify at the right time - CHECK_EQ("unknown?", toString(requireType("idx"))); - CHECK_EQ("unknown?", toString(requireType("val"))); - } - else - { - CHECK_EQ("unknown", toString(requireType("idx"))); - CHECK_EQ("unknown", toString(requireType("val"))); - } + CHECK_EQ("(unknown) -> (unknown, unknown)", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "conditional_refinement_should_stay_error_suppressing") diff --git a/tests/TypeInfer.typestates.test.cpp b/tests/TypeInfer.typestates.test.cpp index e8044968..f269e129 100644 --- a/tests/TypeInfer.typestates.test.cpp +++ b/tests/TypeInfer.typestates.test.cpp @@ -315,4 +315,88 @@ TEST_CASE_FIXTURE(TypeStateFixture, "local_t_is_assigned_a_fresh_table_with_x_as CHECK("boolean | number | number | string" == toString(requireType("x"))); } +TEST_CASE_FIXTURE(TypeStateFixture, "captured_locals_are_unions_of_all_assignments") +{ + CheckResult result = check(R"( + local x = nil + + function f() + print(x) + x = "five" + end + + x = 5 + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(number | string)?" == toString(requireTypeAtPosition({4, 18}))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "captured_locals_are_unions_of_all_assignments_2") +{ + CheckResult result = check(R"( + local t = {x = nil} + + function f() + print(t.x) + t = {x = "five"} + end + + t = {x = 5} + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("{ x: nil } | { x: number } | { x: string }" == toString(requireTypeAtPosition({4, 18}), {true})); + CHECK("(number | string)?" == toString(requireTypeAtPosition({4, 20}))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "prototyped_recursive_functions") +{ + CheckResult result = check(R"( + local f + function f() + if math.random() > 0.5 then + f() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(() -> ())?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "prototyped_recursive_functions_but_has_future_assignments") +{ + CheckResult result = check(R"( + local f + function f() + if math.random() > 0.5 then + f() + end + end + f = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TypeStateFixture, "prototyped_recursive_functions_but_has_previous_assignments") +{ + CheckResult result = check(R"( + local f + f = 5 + function f() + if math.random() > 0.5 then + f() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("((() -> ()) | number)?" == toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 0ba9c249..03b4d90d 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -268,9 +268,12 @@ ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.while_body_are_also_refined ProvisionalTests.xpcall_returns_what_f_returns RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number +RefinementTest.assert_non_binary_expressions_actually_resolve_constraints +RefinementTest.correctly_lookup_a_shadowed_local_that_which_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_truthiness_of_x +RefinementTest.either_number_or_string RefinementTest.fail_to_refine_a_property_of_subscript_expression RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.function_call_with_colon_after_refining_not_to_be_nil @@ -288,6 +291,7 @@ RefinementTest.string_not_equal_to_string_or_nil RefinementTest.truthy_constraint_on_properties RefinementTest.type_annotations_arent_relevant_when_doing_dataflow_analysis RefinementTest.type_comparison_ifelse_expression +RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position @@ -394,6 +398,7 @@ TableTests.table_unifies_into_map TableTests.top_table_type TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.unification_of_unions_in_a_self_referential_type +TableTests.unifying_tables_shouldnt_uaf1 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon TableTests.used_dot_instead_of_colon_but_correctly @@ -403,6 +408,7 @@ ToDot.function ToString.exhaustive_toString_of_cyclic_table ToString.free_types ToString.named_metatable_toStringNamedFunction +ToString.no_parentheses_around_cyclic_function_type_in_intersection ToString.pick_distinct_names_for_mixed_explicit_and_implicit_generics ToString.primitive ToString.tostring_unsee_ttv_if_array @@ -466,6 +472,7 @@ TypeInfer.no_stack_overflow_from_isoptional TypeInfer.promote_tail_type_packs TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 +TypeInfer.statements_are_topologically_sorted TypeInfer.stringify_nested_unions_with_optionals TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice @@ -481,6 +488,7 @@ TypeInferAnyError.for_in_loop_iterator_is_error TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.intersection_of_any_can_have_props +TypeInferAnyError.metatable_of_any_can_be_a_table TypeInferAnyError.quantify_any_does_not_bind_to_itself TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.type_error_addition @@ -502,6 +510,7 @@ TypeInferFunctions.apply_of_lambda_with_inferred_and_explicit_types TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.check_function_bodies +TypeInferFunctions.complicated_return_types_require_an_explicit_annotation TypeInferFunctions.concrete_functions_are_not_supertypes_of_function TypeInferFunctions.dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists @@ -538,7 +547,9 @@ TypeInferFunctions.it_is_ok_to_oversaturate_a_higher_order_function_argument TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.luau_subtyping_is_np_hard +TypeInferFunctions.mutual_recursion TypeInferFunctions.no_lossy_function_type +TypeInferFunctions.num_is_solved_after_num_or_str TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.other_things_are_not_related_to_function TypeInferFunctions.param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible @@ -554,6 +565,7 @@ TypeInferFunctions.too_many_arguments TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.toposort_doesnt_break_mutual_recursion TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_explore_raycast_minimization @@ -598,6 +610,7 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory +TypeInferOOP.method_depends_on_table TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.object_constructor_can_refer_to_method_of_self TypeInferOOP.promise_type_error_too_complex @@ -662,6 +675,7 @@ UnionTypes.disallow_less_specific_assign2 UnionTypes.error_detailed_optional UnionTypes.error_detailed_union_all UnionTypes.error_detailed_union_part +UnionTypes.error_takes_optional_arguments UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.index_on_a_union_type_with_mixed_types @@ -681,6 +695,7 @@ UnionTypes.optional_union_follow UnionTypes.optional_union_functions UnionTypes.optional_union_members UnionTypes.optional_union_methods +UnionTypes.return_types_can_be_disjoint UnionTypes.table_union_write_indirect UnionTypes.unify_unsealed_table_union_check UnionTypes.union_of_functions From 2173938eb08ed35d4dc3eea86ddabcfc01fe5b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petri=20H=C3=A4kkinen?= Date: Fri, 15 Dec 2023 01:05:51 +0200 Subject: [PATCH 2/2] Add tagged lightuserdata (#1087) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change adds support for tagged lightuserdata and optional custom typenames for lightuserdata. Background: Lightuserdata is an efficient representation for many kinds of unmanaged handles and resources in a game engine. However, currently the VM only supports one kind of lightuserdata, which makes it problematic in practice. For example, it's not possible to distinguish between different kinds of lightuserdata in Lua bindings, which can lead to unsafe practices and even crashes when a wrong kind of lightuserdata is passed to a binding function. Tagged lightuserdata work similarly to tagged userdata, i.e. they allow checking the tag quickly using lua_tolightuserdatatagged (or lua_lightuserdatatag). The tag is stored in the 'extra' field of TValue so it will add no cost to the (untagged) lightuserdata type. Alternatives would be to use full userdata values or use bitpacking to embed type information into lightuserdata on application level. Unfortunately these options are not that great in practice: full userdata have major performance implications and bitpacking fails in cases where full 64 bits are already used (e.g. pointers or 64-bit hashes). Lightuserdata names are not strictly necessary but they are rather convenient when debugging Lua code. More precise error messages and tostring returning more specific typename are useful to have in practice (e.g. "resource" or "entity" instead of the more generic "userdata"). Impl note: I did not add support for renaming tags in lua_setlightuserdataname as I'm not sure if it's possible to free fixed strings. If it's simple enough, maybe we should allow renaming (although I can't think of a specific need for it)? --------- Co-authored-by: Petri Häkkinen --- CodeGen/src/CodeGenUtils.cpp | 8 ++++---- VM/include/lua.h | 8 +++++++- VM/include/luaconf.h | 5 +++++ VM/src/lapi.cpp | 37 ++++++++++++++++++++++++++++++++++-- VM/src/lobject.cpp | 4 ++-- VM/src/lobject.h | 10 +++++++++- VM/src/lstate.cpp | 2 ++ VM/src/lstate.h | 2 ++ VM/src/ltm.cpp | 12 ++++++++++++ VM/src/lvmexecute.cpp | 16 +++++++++------- VM/src/lvmutils.cpp | 2 +- tests/Conformance.test.cpp | 27 ++++++++++++++++++++++++++ 12 files changed, 115 insertions(+), 18 deletions(-) diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index c1a9c338..973829ca 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -71,7 +71,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(e)) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); setnvalue(ra + 3, double(index + 1)); setobj2s(L, ra + 4, e); @@ -90,7 +90,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj(L, ra + 4, gval(n)); @@ -115,7 +115,7 @@ bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra) if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj(L, ra + 4, gval(n)); @@ -697,7 +697,7 @@ const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId ba { // set up registers for builtin iteration setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); setnilvalue(ra); } else diff --git a/VM/include/lua.h b/VM/include/lua.h index dbf19cb4..0390de7c 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -159,9 +159,11 @@ LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); LUA_API void* lua_tolightuserdata(lua_State* L, int idx); +LUA_API void* lua_tolightuserdatatagged(lua_State* L, int idx, int tag); LUA_API void* lua_touserdata(lua_State* L, int idx); LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag); LUA_API int lua_userdatatag(lua_State* L, int idx); +LUA_API int lua_lightuserdatatag(lua_State* L, int idx); LUA_API lua_State* lua_tothread(lua_State* L, int idx); LUA_API void* lua_tobuffer(lua_State* L, int idx, size_t* len); LUA_API const void* lua_topointer(lua_State* L, int idx); @@ -186,7 +188,7 @@ LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debug LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API int lua_pushthread(lua_State* L); -LUA_API void lua_pushlightuserdata(lua_State* L, void* p); +LUA_API void lua_pushlightuserdatatagged(lua_State* L, void* p, int tag); LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -323,6 +325,9 @@ typedef void (*lua_Destructor)(lua_State* L, void* userdata); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, lua_Destructor dtor); LUA_API lua_Destructor lua_getuserdatadtor(lua_State* L, int tag); +LUA_API void lua_setlightuserdataname(lua_State* L, int tag, const char* name); +LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); + LUA_API void lua_clonefunction(lua_State* L, int idx); LUA_API void lua_cleartable(lua_State* L, int idx); @@ -370,6 +375,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) #define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) #define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) +#define lua_pushlightuserdata(L, p) lua_pushlightuserdatatagged(L, p, 0) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 7a1bbb95..910e259a 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -101,6 +101,11 @@ #define LUA_UTAG_LIMIT 128 #endif +// number of valid Lua lightuserdata tags +#ifndef LUA_LUTAG_LIMIT +#define LUA_LUTAG_LIMIT 128 +#endif + // upper bound for number of size classes used by page allocator #ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 355e4e21..58c767f1 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -505,6 +505,12 @@ void* lua_tolightuserdata(lua_State* L, int idx) return (!ttislightuserdata(o)) ? NULL : pvalue(o); } +void* lua_tolightuserdatatagged(lua_State* L, int idx, int tag) +{ + StkId o = index2addr(L, idx); + return (!ttislightuserdata(o) || lightuserdatatag(o) != tag) ? NULL : pvalue(o); +} + void* lua_touserdata(lua_State* L, int idx) { StkId o = index2addr(L, idx); @@ -530,6 +536,14 @@ int lua_userdatatag(lua_State* L, int idx) return -1; } +int lua_lightuserdatatag(lua_State* L, int idx) +{ + StkId o = index2addr(L, idx); + if (ttislightuserdata(o)) + return lightuserdatatag(o); + return -1; +} + lua_State* lua_tothread(lua_State* L, int idx) { StkId o = index2addr(L, idx); @@ -665,9 +679,10 @@ void lua_pushboolean(lua_State* L, int b) api_incr_top(L); } -void lua_pushlightuserdata(lua_State* L, void* p) +void lua_pushlightuserdatatagged(lua_State* L, void* p, int tag) { - setpvalue(L->top, p); + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + setpvalue(L->top, p, tag); api_incr_top(L); } @@ -1412,6 +1427,24 @@ lua_Destructor lua_getuserdatadtor(lua_State* L, int tag) return L->global->udatagc[tag]; } +void lua_setlightuserdataname(lua_State* L, int tag, const char* name) +{ + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + api_check(L, !L->global->lightuserdataname[tag]); // renaming not supported + if (!L->global->lightuserdataname[tag]) + { + L->global->lightuserdataname[tag] = luaS_new(L, name); + luaS_fix(L->global->lightuserdataname[tag]); // never collect these names + } +} + +const char* lua_getlightuserdataname(lua_State* L, int tag) +{ + api_check(L, unsigned(tag) < LUA_LUTAG_LIMIT); + const TString* name = L->global->lightuserdataname[tag]; + return name ? getstr(name) : nullptr; +} + void lua_clonefunction(lua_State* L, int idx) { luaC_checkGC(L); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index 88d8d7ca..081e3314 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -48,7 +48,7 @@ int luaO_rawequalObj(const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); @@ -71,7 +71,7 @@ int luaO_rawequalKey(const TKey* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); default: LUAU_ASSERT(iscollectable(t1)); return gcvalue(t1) == gcvalue(t2); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 71640140..d236f7e4 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -80,6 +80,11 @@ typedef struct lua_TValue #define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) +#define lightuserdatatag(o) check_exp(ttislightuserdata(o), (o)->extra[0]) + +// Internal tags used by the VM +#define LU_TAG_ITERATOR LUA_UTAG_LIMIT + /* ** for internal debug only */ @@ -120,10 +125,11 @@ typedef struct lua_TValue } #endif -#define setpvalue(obj, x) \ +#define setpvalue(obj, x, tag) \ { \ TValue* i_o = (obj); \ i_o->value.p = (x); \ + i_o->extra[0] = (tag); \ i_o->tt = LUA_TLIGHTUSERDATA; \ } @@ -492,3 +498,5 @@ LUAI_FUNC int luaO_str2d(const char* s, double* result); LUAI_FUNC const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUAI_FUNC const char* luaO_pushfstring(lua_State* L, const char* fmt, ...); LUAI_FUNC const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen); + +LUAU_FASTFLAG(TaggedLuData) diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 161dcda0..858f61a3 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -210,6 +210,8 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) g->udatagc[i] = NULL; + for (i = 0; i < LUA_LUTAG_LIMIT; i++) + g->lightuserdataname[i] = NULL; for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) g->memcatbytes[i] = 0; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index ed73d3d8..2c6d35dc 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -214,6 +214,8 @@ typedef struct global_State void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 927a535b..3a9fddaa 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -129,6 +129,18 @@ const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) if (ttisstring(type)) return tsvalue(type); } + else if (FFlag::TaggedLuData && ttislightuserdata(o)) + { + int tag = lightuserdatatag(o); + + if (unsigned(tag) < LUA_LUTAG_LIMIT) + { + const TString* name = L->global->lightuserdataname[tag]; + + if (name) + return name; + } + } else if (Table* mt = L->global->mt[ttype(o)]) { const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c1a3ca8e..1c77fa14 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -135,6 +135,8 @@ // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. #define VM_HAS_NATIVE 1 +LUAU_FASTFLAGVARIABLE(TaggedLuData, false) + LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { ptrdiff_t base = savestack(L, L->base); @@ -1110,7 +1112,7 @@ reentry: VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += pvalue(ra) == pvalue(rb) ? LUAU_INSN_D(insn) : 1; + pc += (pvalue(ra) == pvalue(rb) && (!FFlag::TaggedLuData || lightuserdatatag(ra) == lightuserdatatag(rb))) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -1225,7 +1227,7 @@ reentry: VM_NEXT(); case LUA_TLIGHTUSERDATA: - pc += pvalue(ra) != pvalue(rb) ? LUAU_INSN_D(insn) : 1; + pc += (pvalue(ra) != pvalue(rb) || (FFlag::TaggedLuData && lightuserdatatag(ra) != lightuserdatatag(rb))) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); VM_NEXT(); @@ -2296,7 +2298,7 @@ reentry: { // set up registers for builtin iteration setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); setnilvalue(ra); } else @@ -2348,7 +2350,7 @@ reentry: if (!ttisnil(e)) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); setnvalue(ra + 3, double(index + 1)); setobj2s(L, ra + 4, e); @@ -2369,7 +2371,7 @@ reentry: if (!ttisnil(gval(n))) { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1)), LU_TAG_ITERATOR); getnodekey(L, ra + 3, n); setobj2s(L, ra + 4, gval(n)); @@ -2421,7 +2423,7 @@ reentry: { setnilvalue(ra); // ra+1 is already the table - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); } else if (!ttisfunction(ra)) { @@ -2450,7 +2452,7 @@ reentry: { setnilvalue(ra); // ra+1 is already the table - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0)), LU_TAG_ITERATOR); } else if (!ttisfunction(ra)) { diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index c4b0b47d..851d778c 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -288,7 +288,7 @@ int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) case LUA_TBOOLEAN: return bvalue(t1) == bvalue(t2); // true must be 1 !! case LUA_TLIGHTUSERDATA: - return pvalue(t1) == pvalue(t2); + return pvalue(t1) == pvalue(t2) && (!FFlag::TaggedLuData || lightuserdatatag(t1) == lightuserdatatag(t2)); case LUA_TUSERDATA: { tm = get_compTM(L, uvalue(t1)->metatable, uvalue(t2)->metatable, TM_EQ); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index a9c5bc37..b530ce55 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAG(LuauBufferDefinitions); LUAU_FASTFLAG(LuauCodeGenFixByteLower); LUAU_FASTFLAG(LuauCompileBufferAnnotation); LUAU_FASTFLAG(LuauLoopInterruptFix); +LUAU_FASTFLAG(TaggedLuData); LUAU_DYNAMIC_FASTFLAG(LuauStricterUtf8); LUAU_FASTINT(CodegenHeuristicsInstructionLimit); @@ -1700,6 +1701,32 @@ TEST_CASE("UserdataApi") CHECK(dtorhits == 42); } +TEST_CASE("LightuserdataApi") +{ + ScopedFastFlag taggedLuData{FFlag::TaggedLuData, true}; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + void* value = (void*)0x12345678; + + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_lightuserdatatag(L, -1) == 1); + CHECK(lua_tolightuserdatatagged(L, -1, 0) == nullptr); + CHECK(lua_tolightuserdatatagged(L, -1, 1) == value); + + lua_setlightuserdataname(L, 1, "id"); + CHECK(!lua_getlightuserdataname(L, 0)); + CHECK(strcmp(lua_getlightuserdataname(L, 1), "id") == 0); + CHECK(strcmp(luaL_typename(L, -1), "id") == 0); + + lua_pushlightuserdatatagged(L, value, 0); + lua_pushlightuserdatatagged(L, value, 1); + CHECK(lua_rawequal(L, -1, -2) == 0); + + globalState.reset(); +} + TEST_CASE("Iter") { runConformance("iter.lua");