Add control flow extraction to jit compiler.

This commit is contained in:
James Roseborough
2008-06-29 18:34:22 +00:00
parent a3b939352d
commit 88770a3630
2 changed files with 125 additions and 141 deletions

View File

@@ -21,7 +21,6 @@
******************************************************************************/
package org.luaj.jit;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -38,14 +37,11 @@ import javax.tools.JavaCompiler.CompilationTask;
import org.luaj.compiler.LuaC;
import org.luaj.debug.Print;
import org.luaj.platform.J2sePlatform;
import org.luaj.vm.LClosure;
import org.luaj.vm.LPrototype;
import org.luaj.vm.LValue;
import org.luaj.vm.LoadState;
import org.luaj.vm.Lua;
import org.luaj.vm.LuaState;
import org.luaj.vm.Platform;
import org.luaj.vm.LoadState.LuaCompiler;
public class LuaJit extends Lua implements LuaCompiler {
@@ -57,42 +53,6 @@ public class LuaJit extends Lua implements LuaCompiler {
LoadState.compiler = new LuaJit();
}
public static void main(String[] args) throws IOException {
Platform.setInstance(new J2sePlatform());
LuaC.install();
String program = "print 'starting'\n" +
"for i=1,10 do\n" +
" print 'hello, world'\n" +
"end";
program = "print 'a'\n" +
"if a then\n" +
" print 'a'\n" +
"elseif b then\n" +
" print 'b'\n" +
"else\n" +
" print 'c'\n" +
"end\n" +
"print 'd'\n";
InputStream is = new ByteArrayInputStream(program.getBytes());
LPrototype p = LuaC.compile(is, "program");
test( p );
LPrototype q = LuaJit.jitCompile( p );
test( q );
}
private static void test(LPrototype p) {
try {
LuaState vm = Platform.newLuaState();
LClosure c = p.newClosure(vm._G);
vm.pushlvalue(c);
vm.call(0, 0);
} catch ( Throwable e ) {
e.printStackTrace();
}
}
private static int filenum = 0;
private static synchronized String filename() {
@@ -166,17 +126,76 @@ public class LuaJit extends Lua implements LuaCompiler {
private static boolean isjump(int i) {
return Lua.GET_OPCODE(i) == OP_JMP;
}
private static String append( String s, String t ) {
return (s==null? t: t==null? s: s+t);
}
private static String[] extractControlFlow( int[] code ) {
int n = code.length;
String[] s = new String[n];
int jmp;
for ( int pc=0; pc<n; pc++ ) {
int i = code[pc];
switch ( Lua.GET_OPCODE(i) ) {
// case OP_TFORLOOP:
// jmp = LuaState.GETARG_sBx(code[pc+1]);
// s[pc+jmp+1] = append( s[pc+jmp+1], "while (true) { /* TFORLOOP */ " );
// s[pc+1] = append( "} /* LOOPBOT */ ", s[pc+1] );
// break;
case OP_JMP:
jmp = LuaState.GETARG_sBx(code[pc]);
if ( jmp < 0 ) {
s[pc+jmp] = append( s[pc+jmp], "while (true) { /* WHILETOP */ " );
s[pc] = append( "} /* LOOPBOT */ ", s[pc] );
break;
} else {
// forward jump to end of loop is a break
int i2 = code[pc+jmp-1];
if ( Lua.GET_OPCODE(i2) == OP_JMP && LuaState.GETARG_sBx(i2) < 0 ) {
s[pc] = append( s[pc], "break " );
break;
}
// forward jump preceded by test is "if" block
if ( istest(code[pc-1]) ) {
s[pc] = append( s[pc], "{ /* IF */ " );
s[pc+jmp+1] = append( "} /* ENDIF */ ", s[pc+jmp+1] );
// end of block preceded by forward jump is else clause
i2 = code[pc+jmp];
int op2 = Lua.GET_OPCODE(i2);
int jmp2 = LuaState.GETARG_sBx(i2);
if ( op2 == OP_JMP && jmp2 > 0 ) {
s[pc+jmp+1] = append( s[pc+jmp+1], "else { /* ELSE */ " );
s[pc+jmp+jmp2+1] = append( "} /* ENDELSE */ ", s[pc+jmp+jmp2+1] );
}
}
}
break;
case OP_FORLOOP:
jmp = LuaState.GETARG_sBx(code[pc]);
s[pc+jmp] = append( s[pc+jmp], "{ /* FORTOP */ " );
s[pc] = append( "} /* LOOPBOT */ ", s[pc] );
break;
}
}
// find local variables, jump points
return s;
}
private static void writeSource( PrintStream ps, String name, LPrototype p ) {
int i, a, b, c, o, n, cb;
LValue rkb, rkc, nvarargs, key, val;
LValue i0, table;
boolean body;
int i, a, b, c, o;
String bs, cs;
int[] code = p.code;
LValue[] k = p.k;
String[] controlflow = extractControlFlow(code);
// class header
ps.print(
@@ -227,32 +246,6 @@ public class LuaJit extends Lua implements LuaCompiler {
}
ps.println();
// find local variables, jump points
int forlevel=0,maxforlevels=0;
int[] closes = new int[code.length];
boolean[] iselse = new boolean[code.length];
for ( int pc=0; pc<code.length; pc++ ) {
i = code[pc];
o = (i >> POS_OP) & MAX_OP;
switch (o) {
case OP_FORPREP:
maxforlevels = Math.max(maxforlevels, ++forlevel);
break;
case OP_FORLOOP:
forlevel--;
break;
case OP_JMP: {
int delta = LuaState.GETARG_sBx(code[pc]);
++closes[pc+delta+1];
if ( istest(code[pc-1]) && isjump(code[pc+delta]) )
iselse[pc+delta+1] = true;
break;
}
}
}
for ( int j=0; j<maxforlevels; j++ )
ps.println("\t\tboolean back"+j+";");
// loop until a return instruction is processed,
// or the vm yields
@@ -262,20 +255,12 @@ public class LuaJit extends Lua implements LuaCompiler {
ps.print( "\n\t\t// ");
Print.printOpCode(ps, p, pc);
ps.println();
if ( controlflow[pc] != null )
ps.println( "\t\t"+controlflow[pc] );
// get instruction
i = code[pc];
// close if-related jump bodies
if ( closes[pc]>0 || iselse[pc] ) {
ps.print("\t\t");
while ( closes[pc]-- > 0 )
ps.print("} ");
if ( iselse[pc] )
ps.print(" else {");
ps.println();
}
// get opcode and first arg
o = (i >> POS_OP) & MAX_OP;
a = (i >> POS_A) & MAXARG_A;
@@ -323,85 +308,42 @@ public class LuaJit extends Lua implements LuaCompiler {
break;
}
case LuaState.OP_GETGLOBAL: {
// b = LuaState.GETARG_Bx(i);
// key = k[b];
// table = cl.env;
// top = base + a;
// table.luaGetTable(this, table, key);
// pw.println("\t\tvm.top = base+"+a+";");
// continue
b = LuaState.GETARG_Bx(i);
ps.println("\t\tenv.luaGetTable(vm, env, k"+b+");");
ps.println("\t\ts"+a+" = vm.stack[--vm.top];");
ps.println("\t\ts"+a+" = vm.luaV_gettable(env, k"+b+");");
break;
}
case LuaState.OP_GETTABLE: {
//b = GETARG_B(i);
//key = GETARG_RKC(k, i);
//table = this.stack[base + b];
//top = base + a;
//table.luaGetTable(this, table, key);
//continue;
b = GETARG_B(i);
cs = GETARG_RKC_jit(i);
ps.println("\t\ts"+b+".luaGetTable(vm, s"+b+", "+cs+");");
ps.println("\t\ts"+a+" = vm.stack[--vm.top];");
ps.println("\t\ts"+a+" = vm.luaV_gettable(s"+b+", "+cs+");");
break;
}
case LuaState.OP_SETGLOBAL: {
//b = LuaState.GETARG_Bx(i);
//key = k[b];
//val = this.stack[base + a];
//table = cl.env;
//table.luaSetTable(this, table, key, val);
//continue;
b = LuaState.GETARG_Bx(i);
ps.println("\t\tenv.luaSetTable(vm, env, k"+b+", s"+a+");");
ps.println("\t\tvm.luaV_settable(env, k"+b+", s"+a+");");
break;
}
case LuaState.OP_SETUPVAL: {
//b = LuaState.GETARG_B(i);
//cl.upVals[b].setValue( this.stack[base + a] );
//continue;
b = LuaState.GETARG_B(i);
ps.println("\t\t\tjcl.upVals["+b+"].setValue(s"+a+");");
break;
}
case LuaState.OP_SETTABLE: {
//key = GETARG_RKB(k, i);
//val = GETARG_RKC(k, i);
//table = this.stack[base + a];
//table.luaSetTable(this, table, key, val);
//continue;
bs = GETARG_RKB_jit(i);
cs = GETARG_RKC_jit(i);
ps.println("\t\ts"+a+".luaSetTable(vm, s"+a+", "+bs+", "+cs+");");
ps.println("\t\tvm.luaV_settable(s"+a+", "+bs+", "+cs+");");
break;
}
case LuaState.OP_NEWTABLE: {
//b = LuaState.GETARG_B(i);
//c = LuaState.GETARG_C(i);
//this.stack[base + a] = new LTable(b, c);
//continue;
b = GETARG_B(i);
c = GETARG_C(i);
ps.println("\t\ts"+a+" = new LTable("+b+","+c+");");
break;
}
case LuaState.OP_SELF: {
//rkb = GETARG_RKB(k, i);
//rkc = GETARG_RKC(k, i);
//top = base + a;
//rkb.luaGetTable(this, rkb, rkc);
//this.stack[base + a + 1] = rkb;
//// StkId rb = RB(i);
//// setobjs2s(L, ra+1, rb);
//// Protect(luaV_gettable(L, rb, RKC(i), ra));
//continue;
bs = GETARG_RKB_jit(i);
cs = GETARG_RKC_jit(i);
ps.println("\t\t"+bs+".luaGetTable(vm, "+bs+", "+cs+");");
ps.println("\t\ts"+(a+1)+" = "+bs+";");
ps.println("\t\ts"+a+" = vm.luaV_gettable((s"+(a+1)+"="+bs+"), "+cs+");");
break;
}
case LuaState.OP_ADD:
@@ -420,25 +362,15 @@ public class LuaJit extends Lua implements LuaCompiler {
break;
}
case LuaState.OP_UNM: {
//rkb = GETARG_RKB(k, i);
//this.stack[base + a] = rkb.luaUnaryMinus();
//continue;
bs = GETARG_RKB_jit(i);
ps.println("\t\ts"+a+" = "+bs+".luaUnaryMinus();");
}
case LuaState.OP_NOT: {
//rkb = GETARG_RKB(k, i);
//this.stack[base + a] = (!rkb.toJavaBoolean() ? LBoolean.TRUE
// : LBoolean.FALSE);
//continue;
bs = GETARG_RKB_jit(i);
ps.println("\t\ts"+a+" = ("+bs+".toJavaBoolean()? LBoolean.TRUE: LBoolean.FALSE);");
break;
}
case LuaState.OP_LEN: {
//rkb = GETARG_RKB(k, i);
//this.stack[base + a] = LInteger.valueOf( rkb.luaLength() );
//continue;
bs = GETARG_RKB_jit(i);
ps.println("\t\ts"+a+" = LInteger.valueOf("+bs+".luaLength());");
}
@@ -484,7 +416,6 @@ public class LuaJit extends Lua implements LuaCompiler {
//continue;
c = LuaState.GETARG_C(i);
ps.println("\t\tif ( "+(c!=0?"!":"")+" s"+a+".toJavaBoolean())");
ps.println("\t\t{");
break;
}
/*
@@ -644,14 +575,14 @@ public class LuaJit extends Lua implements LuaCompiler {
String limit = "s"+(a+1);
String step = "s"+(a+2);
String idx = "s"+(a+3);
String back = "back"+(forlevel++);
String back = "back"+pc;
ps.println( "\t\tboolean "+back+";");
ps.println( "\t\tfor ( "+idx+"="+init+", "+back+"="+step+".luaBinCmpInteger(Lua.OP_LT,0);\n" +
"\t\t\t"+back+"? "+idx+".luaBinCmpUnknown(Lua.OP_LE, "+limit+"): "+limit+".luaBinCmpUnknown(Lua.OP_LE, "+idx+");\n" +
"\t\t\t"+idx+"="+idx+".luaBinOpUnknown(Lua.OP_ADD,"+step+") ) {");
break;
}
case LuaState.OP_FORLOOP: {
--forlevel;
ps.println( "\t\t}");
//i0 = this.stack[base + a];
//step = this.stack[base + a + 2];

View File

@@ -0,0 +1,53 @@
package org.luaj.jit;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import junit.framework.TestCase;
import org.luaj.compiler.LuaC;
import org.luaj.platform.J2sePlatform;
import org.luaj.vm.LClosure;
import org.luaj.vm.LPrototype;
import org.luaj.vm.LuaState;
import org.luaj.vm.Platform;
/**
* Simple test cases for lua jit basic functional test
*/
public class LuaJitBasicTest extends TestCase {
static {
Platform.setInstance(new J2sePlatform());
LuaC.install();
}
public void testPrintHelloWorld() throws IOException {
stringTest( "print( 'hello, world' )" );
}
public void testForLoop() throws IOException {
stringTest( "print 'starting'\n" +
"for i=1,3 do\n" +
" print( 'i', i )\n" +
"end");
}
private void stringTest(String program) throws IOException {
InputStream is = new ByteArrayInputStream(program.getBytes());
LPrototype p = LuaC.compile(is, "program");
run( p );
LPrototype q = LuaJit.jitCompile( p );
assertTrue(p!=q);
run( q );
}
private static void run(LPrototype p) {
LuaState vm = Platform.newLuaState();
LClosure c = p.newClosure(vm._G);
vm.pushlvalue(c);
vm.call(0, 0);
}
}