diff --git a/jse/src/main/java/org/luaj/vm2/script/LuaScriptEngine.java b/jse/src/main/java/org/luaj/vm2/script/LuaScriptEngine.java index 297f480d..1ea0a579 100644 --- a/jse/src/main/java/org/luaj/vm2/script/LuaScriptEngine.java +++ b/jse/src/main/java/org/luaj/vm2/script/LuaScriptEngine.java @@ -22,6 +22,9 @@ package org.luaj.vm2.script; import java.io.*; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; import javax.script.*; @@ -29,6 +32,7 @@ import org.luaj.vm2.*; import org.luaj.vm2.libs.ThreeArgFunction; import org.luaj.vm2.libs.TwoArgFunction; import org.luaj.vm2.libs.jse.CoerceJavaToLua; +import org.luaj.vm2.libs.jse.CoerceLuaToJava; /** * Implementation of the ScriptEngine interface which can compile and execute @@ -41,7 +45,7 @@ import org.luaj.vm2.libs.jse.CoerceJavaToLua; * and for client bindings use the default engine scoped bindings or * construct a {@link LuajBindings} directly. */ -public class LuaScriptEngine extends AbstractScriptEngine implements ScriptEngine, Compilable { +public class LuaScriptEngine extends AbstractScriptEngine implements ScriptEngine, Compilable, Invocable { private static final String __ENGINE_VERSION__ = Lua._VERSION; private static final String __NAME__ = "Luaj"; @@ -132,6 +136,133 @@ public class LuaScriptEngine extends AbstractScriptEngine implements ScriptEngin return myFactory; } + @Override + public Object invokeFunction(String name, Object... args) throws ScriptException, NoSuchMethodException { + LuaValue function = getInvocableValue(currentGlobals(), name); + return invokeValue(function, toLuaVarargs(args)); + } + + @Override + public Object invokeMethod(Object thiz, String name, Object... args) throws ScriptException, NoSuchMethodException { + LuaValue target = thiz instanceof LuaValue ? (LuaValue) thiz : toLua(thiz); + LuaValue function = target.get(name); + if (!function.isfunction()) { + throw new NoSuchMethodException(name + " on " + target.typename()); + } + return invokeValue(function, prependSelf(target, args)); + } + + @Override + public T getInterface(Class clasz) { + return getInterface((Object) currentGlobals(), clasz); + } + + @Override + public T getInterface(Object thiz, Class clasz) { + if (thiz == null || clasz == null || !clasz.isInterface()) { + return null; + } + LuaValue target = thiz instanceof LuaValue ? (LuaValue) thiz : toLua(thiz); + Globals globals = currentGlobals(); + return hasAllInterfaceMethods(target, clasz, thiz == globals) ? createInterfaceProxy(target, clasz, thiz != globals) : null; + } + + private LuaValue getInvocableValue(Globals globals, String name) throws NoSuchMethodException { + prepareBindings(globals, getContext().getBindings(ScriptContext.ENGINE_SCOPE)); + LuaValue function = globals.get(name); + if (!function.isfunction()) { + throw new NoSuchMethodException(name); + } + return function; + } + + private Object invokeValue(LuaValue function, Varargs args) throws ScriptException { + try { + return toJava(function.invoke(args)); + } catch (LuaError e) { + throw scriptException(e); + } + } + + private static ScriptException scriptException(LuaError e) { + ScriptException se = new ScriptException(e.getMessage()); + se.initCause(e); + return se; + } + + private void prepareBindings(Globals globals, Bindings bindings) { + globals.setmetatable(new BindingsMetatable(bindings)); + } + + private Globals currentGlobals() { + return ((LuajContext) getContext()).globals; + } + + private static LuaValue[] toLuaArgs(Object[] args) { + if (args == null || args.length == 0) { + return new LuaValue[0]; + } + LuaValue[] values = new LuaValue[args.length]; + for (int i = 0; i < args.length; ++i) { + values[i] = toLua(args[i]); + } + return values; + } + + private static Varargs toLuaVarargs(Object[] args) { + return LuaValue.varargsOf(toLuaArgs(args)); + } + + private static Varargs prependSelf(LuaValue target, Object[] args) { + return LuaValue.varargsOf(target, toLuaVarargs(args)); + } + + private boolean hasAllInterfaceMethods(LuaValue target, Class clasz, boolean globalTarget) { + if (globalTarget) { + prepareBindings(currentGlobals(), getContext().getBindings(ScriptContext.ENGINE_SCOPE)); + } + Method[] methods = clasz.getMethods(); + for (int i = 0; i < methods.length; ++i) { + Method method = methods[i]; + if (method.getDeclaringClass() == Object.class) { + continue; + } + if (!target.get(method.getName()).isfunction()) { + return false; + } + } + return true; + } + + private T createInterfaceProxy(final LuaValue target, final Class clasz, final boolean methodStyle) { + return clasz.cast(Proxy.newProxyInstance( + clasz.getClassLoader(), + new Class[] { clasz }, + new InvocationHandler() { + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (method.getDeclaringClass() == Object.class) { + String name = method.getName(); + if ("toString".equals(name)) { + return "LuajInterfaceProxy(" + clasz.getName() + ")"; + } + if ("hashCode".equals(name)) { + return Integer.valueOf(System.identityHashCode(proxy)); + } + if ("equals".equals(name)) { + return Boolean.valueOf(proxy == args[0]); + } + } + LuaValue function = target.get(method.getName()); + if (!function.isfunction()) { + throw new NoSuchMethodException(method.getName()); + } + Varargs result = methodStyle ? target.invokemethod(method.getName(), toLuaVarargs(args)) : function.invoke(toLuaVarargs(args)); + LuaValue value = result.arg1(); + return CoerceLuaToJava.coerce(value, method.getReturnType()); + } + })); + } + class LuajCompiledScript extends CompiledScript { final LuaFunction function; @@ -158,7 +289,7 @@ public class LuaScriptEngine extends AbstractScriptEngine implements ScriptEngin } Object eval(Globals g, Bindings b) throws ScriptException { - g.setmetatable(new BindingsMetatable(b)); + prepareBindings(g, b); LuaFunction f = function; if (f.isclosure()) f = new LuaClosure(f.checkclosure().p, g); diff --git a/jse/src/test/java/org/luaj/vm2/script/ScriptEngineTests.java b/jse/src/test/java/org/luaj/vm2/script/ScriptEngineTests.java index 88e63497..a85cb2f6 100644 --- a/jse/src/test/java/org/luaj/vm2/script/ScriptEngineTests.java +++ b/jse/src/test/java/org/luaj/vm2/script/ScriptEngineTests.java @@ -28,6 +28,7 @@ import java.io.Reader; import javax.script.Bindings; import javax.script.Compilable; import javax.script.CompiledScript; +import javax.script.Invocable; import javax.script.ScriptContext; import javax.script.ScriptEngine; import javax.script.ScriptEngineFactory; @@ -53,6 +54,7 @@ public class ScriptEngineTests extends TestSuite { suite.addTest( new TestSuite( CompileNonClosureTest.class, "Compile NonClosure" ) ); suite.addTest( new TestSuite( UserContextTest.class, "User Context" ) ); suite.addTest( new TestSuite( WriterTest.class, "Writer" ) ); + suite.addTest( new TestSuite( InvocableTest.class, "Invocable" ) ); return suite; } @@ -308,4 +310,78 @@ public class ScriptEngineTests extends TestSuite { output.reset(); } } + + public interface Adder { + int add(int x, int y); + } + + public static class InvocableTest extends TestCase { + private ScriptEngine e; + private Invocable inv; + + public void setUp() { + this.e = new ScriptEngineManager().getEngineByName("luaj"); + this.inv = (Invocable) e; + } + + public void testInvokeFunction() throws Exception { + e.eval("function add(x, y) return x + y end"); + assertEquals(7, inv.invokeFunction("add", 3, 4)); + } + + public void testInvokeMethod() throws Exception { + Object table = e.eval("return { add = function(self, x, y) return x + y end }"); + assertEquals(9, inv.invokeMethod(table, "add", 4, 5)); + } + + public void testInvokeFunctionMissingThrowsNoSuchMethod() throws Exception { + try { + inv.invokeFunction("missing"); + fail("expected NoSuchMethodException"); + } catch (NoSuchMethodException e) { + assertEquals("missing", e.getMessage()); + } + } + + public void testInvokeMethodMissingThrowsNoSuchMethod() throws Exception { + Object table = e.eval("return {}"); + try { + inv.invokeMethod(table, "missing"); + fail("expected NoSuchMethodException"); + } catch (NoSuchMethodException e) { + assertEquals("missing on table", e.getMessage()); + } + } + + public void testInvokeFunctionRuntimeErrorHasCause() throws Exception { + e.eval("function explode() error('boom') end"); + try { + inv.invokeFunction("explode"); + fail("expected ScriptException"); + } catch (ScriptException e) { + assertEquals("boom", e.getMessage()); + assertNotNull(e.getCause()); + } + } + + public void testGetInterfaceFromGlobals() throws Exception { + e.eval("function add(x, y) return x + y end"); + Adder adder = inv.getInterface(Adder.class); + assertNotNull(adder); + assertEquals(11, adder.add(5, 6)); + } + + public void testGetInterfaceFromTable() throws Exception { + Object table = e.eval("local t = {} function t:add(x, y) return x + y end return t"); + Adder adder = inv.getInterface(table, Adder.class); + assertNotNull(adder); + assertEquals(13, adder.add(6, 7)); + } + + public void testGetInterfaceReturnsNullWhenMethodMissing() throws Exception { + assertNull(inv.getInterface(Adder.class)); + Object table = e.eval("return {}"); + assertNull(inv.getInterface(table, Adder.class)); + } + } }