diff --git a/src/Thread.ts b/src/Thread.ts new file mode 100644 index 0000000..1ac255b --- /dev/null +++ b/src/Thread.ts @@ -0,0 +1,54 @@ +import { LuaError } from './LuaError' +import { LuaType } from './utils' + +export type ThreadStatus = 'running' | 'suspended' | 'dead' + +type Gen = Generator + +type GenFn = (...args: LuaType[]) => Gen + +class Thread { + private readonly fn: GenFn + private gen?: Gen + public status: ThreadStatus = 'suspended' + public last: LuaType[] = [] + + public static current: Thread + public static main: Thread + + public constructor(fn: GenFn) { + this.fn = fn + } + + public resume(...args: LuaType[]): LuaType[] { + if (this.status === 'dead') { + throw new LuaError('cannot resume dead coroutine') + } + const prev = Thread.current + Thread.current = this + this.status = 'running' + try { + if (!this.gen) { + this.gen = this.fn(...args) + const r = this.gen.next() + this.status = r.done ? 'dead' : 'suspended' + this.last = r.value || [] + return this.last + } + const r = this.gen.next(args) + this.status = r.done ? 'dead' : 'suspended' + this.last = r.value || [] + return this.last + } finally { + Thread.current = prev + } + } +} + +// eslint-disable-next-line @typescript-eslint/no-empty-function +const mainThread = new Thread((function* () {}) as unknown as GenFn) +mainThread.status = 'running' +Thread.main = mainThread +Thread.current = mainThread + +export { Thread } diff --git a/src/index.ts b/src/index.ts index 7e4fccf..4ebd2e4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,24 +7,36 @@ import { Table } from './Table' import { LuaError } from './LuaError' import { libMath } from './lib/math' import { libTable } from './lib/table' -import { libString, metatable as stringMetatable } from './lib/string' -import { getLibOS } from './lib/os' -import { getLibPackage } from './lib/package' -import { LuaType, ensureArray, Config } from './utils' +import { libString, metatable as stringMetatable } from './lib/string' +import { getLibOS } from './lib/os' +import { getLibPackage } from './lib/package' +import { libCoroutine } from './lib/coroutine' +import { LuaType, ensureArray, Config } from './utils' +import { Thread } from './Thread' import { parse as parseScript } from './parser' interface Script { exec: () => LuaType } -const call = (f: Function | Table, ...args: LuaType[]): LuaType[] => { - if (f instanceof Function) return ensureArray(f(...args)) - - const mm = f instanceof Table && f.getMetaMethod('__call') - if (mm) return ensureArray(mm(f, ...args)) - - throw new LuaError(`attempt to call an uncallable type`) -} +const call = (f: Function | Table | Thread, ...args: LuaType[]): LuaType[] => { + if (f instanceof Thread) return f.resume(...args) + + if (f instanceof Function) { + const res = f(...args) + if (res && typeof res.next === 'function') { + let r = res.next() + while (!r.done) r = res.next() + return ensureArray(r.value) + } + return ensureArray(res as LuaType) + } + + const mm = f instanceof Table && f.getMetaMethod('__call') + if (mm) return ensureArray(mm(f, ...args)) + + throw new LuaError(`attempt to call an uncallable type`) +} const stringTable = new Table() stringTable.metatable = stringMetatable @@ -37,18 +49,22 @@ const get = (t: Table | string, v: LuaType): LuaType => { } const execChunk = (_G: Table, chunk: string, chunkName?: string): LuaType[] => { - const exec = new Function('__lua', chunk) - const globalScope = new Scope(_G.strValues).extend() - if (chunkName) globalScope.setVarargs([chunkName]) - const res = exec({ - globalScope, - ...operators, - Table, - call, - get - }) - return res === undefined ? [undefined] : res -} + const exec = new Function(`return ${chunk}`)() as (lua: unknown) => Generator + const globalScope = new Scope(_G.strValues).extend() + if (chunkName) globalScope.setVarargs([chunkName]) + const iterator = exec({ + globalScope, + ...operators, + Table, + call, + get + }) + let res = iterator.next() + while (!res.done) { + res = iterator.next() + } + return res.value === undefined ? [undefined] : res.value +} function createEnv( config: Config = {} @@ -81,8 +97,9 @@ function createEnv( loadLib('package', libPackage) loadLib('math', libMath) loadLib('table', libTable) - loadLib('string', libString) - loadLib('os', getLibOS(cfg)) + loadLib('string', libString) + loadLib('os', getLibOS(cfg)) + loadLib('coroutine', libCoroutine) _G.rawset('require', _require) diff --git a/src/lib/coroutine.ts b/src/lib/coroutine.ts new file mode 100644 index 0000000..1c6e256 --- /dev/null +++ b/src/lib/coroutine.ts @@ -0,0 +1,43 @@ +import { Table } from '../Table' +import { LuaType, coerceArgToFunction, coerceArgToThread } from '../utils' +import { LuaError } from '../LuaError' +import { Thread } from '../Thread' + +function create(fn: LuaType): Thread { + const F = coerceArgToFunction(fn, 'create', 1) + return new Thread(F as (...args: LuaType[]) => Generator) +} + +function resume(thread: LuaType, ...args: LuaType[]): LuaType[] { + const THREAD = coerceArgToThread(thread, 'resume', 1) + try { + return [true, ...THREAD.resume(...args)] + } catch (e) { + if (e instanceof LuaError) return [false, e.message] + throw e + } +} + +function status(thread: LuaType): string { + const THREAD = coerceArgToThread(thread, 'status', 1) + return THREAD.status +} + +function wrap(fn: LuaType): Function { + const thread = create(fn) + return (...args: LuaType[]): LuaType[] => thread.resume(...args) +} + +function running(): LuaType[] { + return [Thread.current, Thread.current === Thread.main] +} + +const libCoroutine = new Table({ + create, + resume, + wrap, + running, + status +}) + +export { libCoroutine } diff --git a/src/parser.ts b/src/parser.ts index dca6690..ecbe020 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -149,15 +149,15 @@ const generate = (node: luaparse.Node): string | MemExpr => { return generate(node.expression) } - case 'FunctionDeclaration': { - const getFuncDef = (params: string[]): string => { - const paramStr = params.join(';\n') - const body = parseBody(node, paramStr) - const argsStr = params.length === 0 ? '' : '...args' - const returnStr = - node.body.findIndex(node => node.type === 'ReturnStatement') === -1 ? '\nreturn []' : '' - return `(${argsStr}) => {\n${body}${returnStr}\n}` - } + case 'FunctionDeclaration': { + const getFuncDef = (params: string[]): string => { + const paramStr = params.join(';\n') + const body = parseBody(node, paramStr) + const argsStr = params.length === 0 ? '' : '...args' + const returnStr = + node.body.findIndex(node => node.type === 'ReturnStatement') === -1 ? '\nreturn []' : '' + return `function* (${argsStr}) {\n${body}${returnStr}\n}` + } const params = node.parameters.map(param => { if (param.type === 'VarargLiteral') { @@ -210,10 +210,10 @@ const generate = (node: luaparse.Node): string | MemExpr => { return `for (let [iterator, table, next] = ${iterators}, res = __lua.call(iterator, table, next); res[0] !== undefined; res = __lua.call(iterator, table, res[0])) {\n${body}\n}` } - case 'Chunk': { - const body = parseBody(node) - return `'use strict'\nconst $0 = __lua.globalScope\nlet vars\nlet vals\nlet label\n\n${body}` - } + case 'Chunk': { + const body = parseBody(node) + return `function* (__lua) {\n'use strict'\nconst $0 = __lua.globalScope\nlet vars\nlet vals\nlet label\n\n${body}\n}` + } case 'Identifier': { return `$${nodeToScope.get(node)}.get('${node.name}')` @@ -321,21 +321,34 @@ const generate = (node: luaparse.Node): string | MemExpr => { return new MemExpr(base, index) } - case 'CallExpression': - case 'TableCallExpression': - case 'StringCallExpression': { - const functionName = expression(node.base) - const args = - node.type === 'CallExpression' - ? parseExpressionList(node.arguments).join(', ') - : expression(node.type === 'TableCallExpression' ? node.arguments : node.argument) - - if (functionName instanceof MemExpr && node.base.type === 'MemberExpression' && node.base.indexer === ':') { - return `__lua.call(${functionName}, ${functionName.base}, ${args})` - } - - return `__lua.call(${functionName}, ${args})` - } + case 'CallExpression': + case 'TableCallExpression': + case 'StringCallExpression': { + if ( + node.type === 'CallExpression' && + node.base.type === 'MemberExpression' && + node.base.base.type === 'Identifier' && + node.base.base.name === 'coroutine' && + node.base.identifier.type === 'Identifier' && + node.base.identifier.name === 'yield' && + node.base.indexer === '.' + ) { + const args = parseExpressions(node.arguments) + return `yield ${args}` + } + + const functionName = expression(node.base) + const args = + node.type === 'CallExpression' + ? parseExpressionList(node.arguments).join(', ') + : expression(node.type === 'TableCallExpression' ? node.arguments : node.argument) + + if (functionName instanceof MemExpr && node.base.type === 'MemberExpression' && node.base.indexer === ':') { + return `__lua.call(${functionName}, ${functionName.base}, ${args})` + } + + return `__lua.call(${functionName}, ${args})` + } default: throw new Error(`No generator found for: ${node.type}`) diff --git a/src/utils.ts b/src/utils.ts index 7f565b4..5810aa2 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,7 +1,8 @@ -import { LuaError } from './LuaError' -import { Table } from './Table' +import { LuaError } from './LuaError' +import { Table } from './Table' +import { Thread } from './Thread' -type LuaType = undefined | boolean | number | string | Function | Table // thread | userdata +type LuaType = undefined | boolean | number | string | Function | Table | Thread // userdata interface Config { LUA_PATH?: string @@ -18,7 +19,7 @@ const FLOATING_POINT_PATTERN = /^[-+]?[0-9]*\.?([0-9]+([eE][-+]?[0-9]+)?)?$/ /** Pattern to identify a hex string value that can validly be converted to a number in Lua */ const HEXIDECIMAL_CONSTANT_PATTERN = /^(-)?0x([0-9a-fA-F]*)\.?([0-9a-fA-F]*)$/ -function type(v: LuaType): 'string' | 'number' | 'boolean' | 'function' | 'nil' | 'table' { +function type(v: LuaType): 'string' | 'number' | 'boolean' | 'function' | 'nil' | 'table' | 'thread' { const t = typeof v switch (t) { @@ -31,11 +32,12 @@ function type(v: LuaType): 'string' | 'number' | 'boolean' | 'function' | 'nil' case 'function': return t - case 'object': - if (v instanceof Table) return 'table' - if (v instanceof Function) return 'function' - } -} + case 'object': + if (v instanceof Table) return 'table' + if (v instanceof Function) return 'function' + if (v instanceof Thread) return 'thread' + } +} function tostring(v: LuaType): string { if (v instanceof Table) { @@ -45,9 +47,13 @@ function tostring(v: LuaType): string { return valToStr(v, 'table: 0x') } - if (v instanceof Function) { - return valToStr(v, 'function: 0x') - } + if (v instanceof Function) { + return valToStr(v, 'function: 0x') + } + + if (v instanceof Thread) { + return valToStr(v, 'thread: 0x') + } return coerceToString(v) @@ -184,21 +190,29 @@ function coerceArgToTable(value: LuaType, funcName: string, index: number): Tabl } } -function coerceArgToFunction(value: LuaType, funcName: string, index: number): Function { - if (value instanceof Function) { - return value - } else { - const typ = type(value) - throw new LuaError(`bad argument #${index} to '${funcName}' (function expected, got ${typ})`) - } -} +function coerceArgToFunction(value: LuaType, funcName: string, index: number): Function { + if (value instanceof Function) { + return value + } else { + const typ = type(value) + throw new LuaError(`bad argument #${index} to '${funcName}' (function expected, got ${typ})`) + } +} + +function coerceArgToThread(value: LuaType, funcName: string, index: number): Thread { + if (value instanceof Thread) { + return value + } + const typ = type(value) + throw new LuaError(`bad argument #${index} to '${funcName}' (thread expected, got ${typ})`) +} const ensureArray = (value: T | T[]): T[] => (value instanceof Array ? value : [value]) const hasOwnProperty = (obj: Record | unknown[], key: string | number): boolean => Object.prototype.hasOwnProperty.call(obj, key) -export { +export { LuaType, Config, type, @@ -209,8 +223,9 @@ export { coerceToString, coerceArgToNumber, coerceArgToString, - coerceArgToTable, - coerceArgToFunction, + coerceArgToTable, + coerceArgToFunction, + coerceArgToThread, ensureArray, hasOwnProperty } diff --git a/tests/test.js b/tests/test.js index 1d18745..bc27821 100644 --- a/tests/test.js +++ b/tests/test.js @@ -54,4 +54,56 @@ let exitCode = 0 } } +{ + const luaEnv = luainjs.createEnv() + const script = luaEnv.parse(` + local co = coroutine.create(function(a) + local x, y = coroutine.yield(a + 1, a + 2) + return x + y + end) + local r1 = {coroutine.resume(co, 3)} + if r1[1] ~= true or r1[2] ~= 4 or r1[3] ~= 5 then return 'fail1' end + local r2 = {coroutine.resume(co, 5, 6)} + if r2[1] ~= true or r2[2] ~= 11 then return 'fail2' end + return 'ok' + `) + if (script.exec() !== 'ok') throw Error('coroutine resume failed') +} + +{ + const luaEnv = luainjs.createEnv() + const script = luaEnv.parse(` + local f = coroutine.wrap(function(a) + local b = coroutine.yield(a + 1) + return a + b + end) + local r1 = {f(3)} + if r1[1] ~= 4 then return 'fail1' end + local r2 = {f(5)} + if r2[1] ~= 8 then return 'fail2' end + return 'ok' + `) + if (script.exec() !== 'ok') throw Error('coroutine wrap failed') +} + +{ + const luaEnv = luainjs.createEnv() + const script = luaEnv.parse(` + local main, isMain = coroutine.running() + if not isMain then return 'fail1' end + local co + co = coroutine.create(function() + if coroutine.status(co) ~= 'running' then return 'fail2' end + local t, m = coroutine.running() + if t ~= co or m then return 'fail3' end + end) + if coroutine.status(co) ~= 'suspended' then return 'fail4' end + local ok = coroutine.resume(co) + if not ok then return 'fail5' end + if coroutine.status(co) ~= 'dead' then return 'fail6' end + return 'ok' + `) + if (script.exec() !== 'ok') throw Error('coroutine running/status failed') +} + process.exit(exitCode)