Refactor math support to provide more consistent and complete math library coverage

This commit is contained in:
James Roseborough
2008-07-21 22:12:06 +00:00
parent 467923b86e
commit 6642b38f91
9 changed files with 352 additions and 169 deletions

View File

@@ -28,39 +28,89 @@ import org.luaj.vm.LFunction;
import org.luaj.vm.LInteger;
import org.luaj.vm.LTable;
import org.luaj.vm.LValue;
import org.luaj.vm.Lua;
import org.luaj.vm.LuaState;
import org.luaj.vm.Platform;
public class MathLib extends LFunction {
public static final String[] NAMES = {
"math",
"abs",
"cos",
// irregular functions
"max",
"min",
"modf",
"sin",
"sqrt",
"ceil",
"floor",
"frexp",
"ldexp",
"random",
"randomseed",
// 2 argument, return double
"atan2",
"fmod",
"pow",
// single argument, return double
"abs",
"acos",
"asin",
"atan",
"cos",
"cosh",
"deg",
"exp",
"log",
"log10",
"rad",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
};
private static final int INSTALL = 0;
private static final int ABS = 1;
private static final int COS = 2;
private static final int MAX = 3;
private static final int MIN = 4;
private static final int MODF = 5;
private static final int SIN = 6;
private static final int SQRT = 7;
private static final int CEIL = 8;
private static final int FLOOR = 9;
private static final int RANDOM = 10;
private static final int RANDOMSEED = 11;
private static final int INSTALL = 0;
// irregular functions
public static final int MAX = 1;
public static final int MIN = 2;
public static final int MODF = 3;
public static final int CEIL = 4;
public static final int FLOOR = 5;
public static final int FREXP = 6;
public static final int LDEXP = 7;
public static final int RANDOM = 8;
public static final int RSEED = 9;
public static final int LAST_IRREGULAR = RSEED;
// 2 argument, return double
public static final int ATAN2 = 10;
public static final int FMOD = 11;
public static final int POW = 12;
public static final int LAST_DOUBLE_ARG = POW;
/* Math operations - single argument, one function */
public static final int ABS = 13;
public static final int ACOS = 14;
public static final int ASIN = 15;
public static final int ATAN = 16;
public static final int COS = 17;
public static final int COSH = 18;
public static final int DEG = 19;
public static final int EXP = 20;
public static final int LOG = 21;
public static final int LOG10 = 22;
public static final int RAD = 23;
public static final int SIN = 24;
public static final int SINH = 25;
public static final int SQRT = 26;
public static final int TAN = 27;
public static final int TANH = 28;
private static Platform platform;
public static void install( LTable globals ) {
LTable math = new LTable();
@@ -69,6 +119,7 @@ public class MathLib extends LFunction {
math.put( "huge", new LDouble( Double.MAX_VALUE ) );
math.put( "pi", new LDouble( Math.PI ) );
globals.put( "math", math );
platform = Platform.getInstance();
}
private static Random random = null;
@@ -83,11 +134,6 @@ public class MathLib extends LFunction {
return NAMES[id]+"()";
}
private static void setResult( LuaState vm, LValue value ) {
vm.resettop();
vm.pushlvalue( value );
}
private static void setResult( LuaState vm, double d ) {
vm.resettop();
vm.pushlvalue( LDouble.valueOf(d) );
@@ -99,86 +145,93 @@ public class MathLib extends LFunction {
}
public boolean luaStackCall( LuaState vm ) {
double x;
switch ( id ) {
case INSTALL:
install( vm._G );
break;
case ABS:
setResult( vm, Math.abs ( vm.checkdouble(2) ) );
break;
case COS:
setResult( vm, Math.cos ( vm.checkdouble(2) ) );
break;
case MAX: {
int n = vm.gettop();
x = vm.checkdouble(2);
for ( int i=3; i<=n; i++ )
x = Math.max(x, vm.checkdouble(i));
setResult( vm, x );
break;
}
case MIN: {
int n = vm.gettop();
x = vm.checkdouble(2);
for ( int i=3; i<=n; i++ )
x = Math.min(x, vm.checkdouble(i));
setResult(vm,x);
break;
}
case MODF: {
double v = vm.checkdouble(2);
double intPart = ( v > 0 ) ? Math.floor( v ) : Math.ceil( v );
double fracPart = v - intPart;
vm.resettop();
vm.pushnumber( intPart );
vm.pushnumber( fracPart );
break;
}
case SIN:
setResult( vm, Math.sin( vm.checkdouble(2) ) );
break;
case SQRT:
setResult( vm, Math.sqrt( vm.checkdouble(2) ) );
break;
case CEIL:
setResult( vm, (int) Math.ceil( vm.checkdouble(2) ) );
break;
case FLOOR:
setResult( vm, (int) Math.floor( vm.checkdouble(2) ) );
break;
case RANDOM: {
if ( random == null )
random = new Random();
switch ( vm.gettop() ) {
case 1:
vm.resettop();
vm.pushnumber(random.nextDouble());
if ( id > LAST_DOUBLE_ARG ) {
setResult( vm, platform.mathop(id, vm.checkdouble(2) ) );
} else if ( id > LAST_IRREGULAR ) {
setResult( vm, platform.mathop(id, vm.checkdouble(2), vm.checkdouble(3) ) );
} else {
switch ( id ) {
case INSTALL:
install( vm._G );
break;
case 2: {
int m = vm.checkint(2);
vm.argcheck(1<=m, 1, "interval is empty");
vm.resettop();
vm.pushinteger(1+random.nextInt(m));
case MAX: {
int n = vm.gettop();
double x = vm.checkdouble(2);
for ( int i=3; i<=n; i++ )
x = Math.max(x, vm.checkdouble(i));
setResult( vm, x );
break;
}
default: {
int m = vm.checkint(2);
int n = vm.checkint(3);
vm.argcheck(m<=n, 2, "interval is empty");
vm.resettop();
vm.pushinteger(m+random.nextInt(n+1-m));
case MIN: {
int n = vm.gettop();
double x = vm.checkdouble(2);
for ( int i=3; i<=n; i++ )
x = Math.min(x, vm.checkdouble(i));
setResult(vm,x);
break;
}
case MODF: {
double x = vm.checkdouble(2);
double intPart = ( x > 0 ) ? Math.floor( x ) : Math.ceil( x );
double fracPart = x - intPart;
vm.resettop();
vm.pushnumber( intPart );
vm.pushnumber( fracPart );
break;
}
case CEIL:
setResult( vm, (int) Math.ceil( vm.checkdouble(2) ) );
break;
case FLOOR:
setResult( vm, (int) Math.floor( vm.checkdouble(2) ) );
break;
case FREXP: {
long bits = Double.doubleToLongBits( vm.checkdouble(2) );
vm.resettop();
vm.pushnumber( ((bits & (~(-1L<<52))) + (1L<<52)) * ((bits >= 0)? (.5 / (1L<<52)): (-.5 / (1L<<52))) );
vm.pushinteger( (((int) (bits >> 52)) & 0x7ff) - 1022 );
break;
}
case LDEXP: {
double m = vm.checkdouble(2);
int e = vm.checkint(3);
vm.resettop();
vm.pushnumber( m * Double.longBitsToDouble(((long)(e+1023)) << 52) );
break;
}
case RANDOM: {
if ( random == null )
random = new Random();
switch ( vm.gettop() ) {
case 1:
vm.resettop();
vm.pushnumber(random.nextDouble());
break;
case 2: {
int m = vm.checkint(2);
vm.argcheck(1<=m, 1, "interval is empty");
vm.resettop();
vm.pushinteger(1+random.nextInt(m));
break;
}
default: {
int m = vm.checkint(2);
int n = vm.checkint(3);
vm.argcheck(m<=n, 2, "interval is empty");
vm.resettop();
vm.pushinteger(m+random.nextInt(n+1-m));
break;
}
}
break;
}
case RSEED:
random = new Random( vm.checkint(2) );
vm.resettop();
break;
default:
LuaState.vmerror( "bad math id" );
}
break;
}
case RANDOMSEED:
random = new Random( vm.checkint(2) );
vm.resettop();
break;
default:
LuaState.vmerror( "bad math id" );
}
return false;
}

View File

@@ -477,19 +477,19 @@ public class StringLib extends LFunction {
*/
static void sub( LuaState vm ) {
final LString s = vm.checklstring(2);
final int len = s.length();
final int l = s.length();
int i = posrelat( vm.checkint( 3 ), len );
int j = posrelat( vm.optint( 4, -1 ), len );
int start = posrelat( vm.checkint( 3 ), l );
int end = posrelat( vm.optint( 4, -1 ), l );
if ( i < 1 )
i = 1;
if ( j > len )
j = len;
if ( start < 1 )
start = 1;
if ( end > l )
end = l;
vm.resettop();
if ( i <= j ) {
LString result = s.substring( i - 1 , j );
if ( start <= end ) {
LString result = s.substring( start-1 , end );
vm.pushlstring( result );
} else {
vm.pushstring( "" );

View File

@@ -87,37 +87,12 @@ public class LDouble extends LNumber {
case Lua.OP_MUL: return new LDouble( lhs * rhs );
case Lua.OP_DIV: return new LDouble( lhs / rhs );
case Lua.OP_MOD: return new LDouble( lhs - Math.floor(lhs/rhs) * rhs );
case Lua.OP_POW: {
// allow platform to override math.pow()
LValue result = Platform.getInstance().mathPow(lhs, rhs);
return (result != null?
result:
new LDouble( dpow( lhs, rhs ) ));
}
case Lua.OP_POW: return Platform.getInstance().mathPow(lhs, rhs);
}
LuaState.vmerror( "bad bin opcode" );
return null;
}
public static double dpow(double a, double b) {
if ( b < 0 )
return 1 / dpow( a, -b );
double p = 1;
int whole = (int) b;
for ( double v=a; whole > 0; whole>>=1, v*=v )
if ( (whole & 1) != 0 )
p *= v;
if ( (b -= whole) > 0 ) {
int frac = (int) (0x10000 * b);
for ( ; (frac&0xffff)!=0; frac<<=1 ) {
a = Math.sqrt(a);
if ( (frac & 0x8000) != 0 )
p *= a;
}
}
return p;
}
public int toJavaInt() {
return (int) m_value;
}

View File

@@ -47,6 +47,7 @@ abstract public class Platform {
private static Platform instance;
/**
* Singleton to be used for platform operations.
*
@@ -177,4 +178,23 @@ abstract public class Platform {
}
return port;
}
/**
* Compute a math operation that takes a single double argument and returns a double
* @param id the math op, from MathLib constants
* @param x the arugment
* @return the value
* @throws LuaErrorException if the id is not supported by this platform.
*/
abstract public double mathop(int id, double x);
/**
* Compute a math operation that takes a two double arguments and returns a double
* @param id the math op, from MathLib constants
* @param x the first arugment
* @param y the second arugment
* @return the value
* @throws LuaErrorException if the id is not supported by this platform.
*/
abstract public double mathop(int id, double x, double y);
}

View File

@@ -8,8 +8,11 @@ import java.io.Reader;
import javax.microedition.midlet.MIDlet;
import org.luaj.debug.net.j2me.DebugSupportImpl;
import org.luaj.lib.MathLib;
import org.luaj.vm.DebugNetSupport;
import org.luaj.vm.LDouble;
import org.luaj.vm.LNumber;
import org.luaj.vm.LuaErrorException;
import org.luaj.vm.LuaState;
import org.luaj.vm.Platform;
@@ -50,6 +53,61 @@ public class J2meMidp10Cldc10Platform extends Platform {
}
public LNumber mathPow(double lhs, double rhs) {
throw new RuntimeException("mathPow(double lhs, double rhs) is not supported.");
return LDouble.valueOf(dpow(lhs,rhs));
}
public double mathop(int id, double a, double b) {
switch ( id ) {
case MathLib.ATAN2: return a==0? (b>0? Math.PI/2: b>0? -Math.PI/2: 0): Math.atan(b/a);
case MathLib.FMOD: return a - (b * ((int)(a/b)));
case MathLib.LDEXP: return a * dpow(2, b);
case MathLib.POW: return dpow(a, b);
}
throw new LuaErrorException( "unsupported math op" );
}
public double mathop(int id, double x) {
switch ( id ) {
case MathLib.ABS: return Math.abs(x);
//case MathLib.ACOS: return Math.acos(x);
//case MathLib.ASIN: return Math.asin(x);
//case MathLib.ATAN: return Math.atan(x);
case MathLib.COS: return Math.cos(x);
case MathLib.COSH: return (Math.exp(x) + Math.exp(-x)) / 2;
case MathLib.DEG: return Math.toDegrees(x);
case MathLib.EXP: return Math.exp(x);
case MathLib.LOG: return Math.log(x);
case MathLib.LOG10: return Math.log10(x);
case MathLib.RAD: return Math.toRadians(x);
case MathLib.SIN: return Math.sin(x);
case MathLib.SINH: return (Math.exp(x) - Math.exp(-x)) / 2;
case MathLib.SQRT: return Math.sqrt(x);
case MathLib.TAN: return Math.tan(x);
case MathLib.TANH: {
double e = Math.exp(2*x);
return (e-1) / (e+1);
}
}
throw new LuaErrorException( "unsupported math op" );
}
public static double dpow(double a, double b) {
if ( b < 0 )
return 1 / dpow( a, -b );
double p = 1;
int whole = (int) b;
for ( double v=a; whole > 0; whole>>=1, v*=v )
if ( (whole & 1) != 0 )
p *= v;
if ( (b -= whole) > 0 ) {
int frac = (int) (0x10000 * b);
for ( ; (frac&0xffff)!=0; frac<<=1 ) {
a = Math.sqrt(a);
if ( (frac & 0x8000) != 0 )
p *= a;
}
}
return p;
}
}

View File

@@ -9,10 +9,13 @@ import java.io.InputStreamReader;
import java.io.Reader;
import org.luaj.debug.net.j2se.DebugSupportImpl;
import org.luaj.lib.MathLib;
import org.luaj.lib.j2se.LuajavaLib;
import org.luaj.vm.DebugNetSupport;
import org.luaj.vm.LDouble;
import org.luaj.vm.LNumber;
import org.luaj.vm.LValue;
import org.luaj.vm.LuaErrorException;
import org.luaj.vm.LuaState;
import org.luaj.vm.Platform;
@@ -53,4 +56,36 @@ public class J2sePlatform extends Platform {
double d = Math.pow(lhs, rhs);
return LDouble.valueOf(d);
}
public double mathop(int id, double a, double b) {
switch ( id ) {
case MathLib.ATAN2: return Math.atan2(a, b);
case MathLib.FMOD: return a - (b * ((int)(a/b)));
case MathLib.LDEXP: return a * Math.pow(2, b);
case MathLib.POW: return Math.pow(a, b);
}
throw new LuaErrorException( "unsupported math op" );
}
public double mathop(int id, double x) {
switch ( id ) {
case MathLib.ABS: return Math.abs(x);
case MathLib.ACOS: return Math.acos(x);
case MathLib.ASIN: return Math.asin(x);
case MathLib.ATAN: return Math.atan(x);
case MathLib.COS: return Math.cos(x);
case MathLib.COSH: return Math.cosh(x);
case MathLib.DEG: return Math.toDegrees(x);
case MathLib.EXP: return Math.exp(x);
case MathLib.LOG: return Math.log(x);
case MathLib.LOG10: return Math.log10(x);
case MathLib.RAD: return Math.toRadians(x);
case MathLib.SIN: return Math.sin(x);
case MathLib.SINH: return Math.sinh(x);
case MathLib.SQRT: return Math.sqrt(x);
case MathLib.TAN: return Math.tan(x);
case MathLib.TANH: return Math.tanh(x);
}
throw new LuaErrorException( "unsupported math op" );
}
}

View File

@@ -5,7 +5,6 @@ local fail = 'fail '
local needcheck = 'needcheck '
local badmsg = 'badmsg '
akey = 'aa'
astring = 'abc'
astrnum = '789'
@@ -62,7 +61,7 @@ end
local function ellipses(v)
local s = tostring(v)
return #s <= 8 and s or string.sub(s,8)..'...'
return #s <= 8 and s or (string.sub(s,1,8)..'...')
end
local pretty = {
@@ -84,6 +83,15 @@ local function values(list)
return table.concat(t,',')
end
local function types(list)
local t = {}
for i=1,#list do
local ai = list[i]
t[i] = type(ai)
end
return table.concat(t,',')
end
local function signature(name,arglist)
return name..'('..values(arglist)..')'
end
@@ -145,14 +153,18 @@ local function subbanner(name)
end
-- check that all combinations of arguments pass
function checkallpass( name, typesets )
function checkallpass( name, typesets, typesonly )
subbanner('checkallpass')
for i,v in arglists(typesets) do
local sig = signature(name,v)
local r = { invoke( name, v ) }
local s = table.remove( r, 1 )
if s then
print( ok, sig, values(r) )
if typesonly then
print( ok, sig, types(r) )
else
print( ok, sig, values(r) )
end
else
print( fail, sig, values(r) )
end

View File

@@ -2,17 +2,26 @@ package.path = "?.lua;src/test/errors/?.lua"
require 'args'
-- arg type tests for math library functions
local somenumber = {23,45.67,'-12','-345.678'}
local somenumber = {1,0.75,'-1','-0.25'}
local somepositive = {1,0.75,'2', '2.5'}
local notanumber = {nil,astring,aboolean,afunction,atable,athread}
local nonnumber = {astring,aboolean,afunction,atable}
local singleargfunctions = {
'abs', 'acos', 'asin', 'atan', 'ceil', 'cos', 'cosh', 'deg', 'exp', 'floor', 'frexp',
'log', 'log10', 'rad', 'randomseed', 'sin', 'sinh', 'sqrt', 'tan', 'tanh',
}
'abs', 'acos', 'asin', 'atan', 'ceil', 'cos', 'cosh', 'deg', 'exp', 'floor',
'rad', 'randomseed', 'sin', 'sinh', 'tan', 'tanh', 'frexp',
}
local singleargposdomain = {
'log', 'log10', 'sqrt',
}
local twoargfunctions = {
'atan2', 'fmod', 'pow',
'atan2', 'fmod',
}
local twoargsposdomain = {
'pow',
}
-- single argument tests
@@ -23,6 +32,14 @@ for i,v in ipairs(singleargfunctions) do
checkallerrors(funcname,{notanumber},'bad argument #1')
end
-- single argument, positive domain tests
for i,v in ipairs(singleargposdomain) do
local funcname = 'math.'..v
banner(funcname)
checkallpass(funcname,{somepositive})
checkallerrors(funcname,{notanumber},'bad argument #1')
end
-- two-argument tests
for i,v in ipairs(twoargfunctions) do
local funcname = 'math.'..v
@@ -35,6 +52,18 @@ for i,v in ipairs(twoargfunctions) do
checkallerrors(funcname,{somenumber,notanumber},'bad argument #2')
end
-- two-argument, positive domain tests
for i,v in ipairs(twoargsposdomain) do
local funcname = 'math.'..v
banner(funcname)
checkallpass(funcname,{somepositive,somenumber})
checkallerrors(funcname,{},'bad argument #1')
checkallerrors(funcname,{notanumber},'bad argument #1')
checkallerrors(funcname,{notanumber,somenumber},'bad argument #1')
checkallerrors(funcname,{somenumber},'bad argument #2')
checkallerrors(funcname,{somenumber,notanumber},'bad argument #2')
end
-- math.max
banner('math.max')
checkallpass('math.max',{somenumber})
@@ -56,11 +85,10 @@ local somem = {3,4.5,'6.7'}
local somen = {8,9.10,'12.34'}
local notamn = {astring,aboolean,atable,afunction}
banner('math.random')
checkallpass('math.random',{})
checkallpass('math.random',{somem})
checkallpass('math.random',{somem,somen})
checkallpass('math.random',{{8},{7.8}})
checkallpass('math.random',{{-4,-5.6,'-7','-8.9'},{-1,100,23.45,'-1.23'}})
checkallpass('math.random',{},true)
checkallpass('math.random',{somem},true)
checkallpass('math.random',{somem,somen},true)
checkallpass('math.random',{{-4,-5.6,'-7','-8.9'},{-1,100,23.45,'-1.23'}},true)
checkallerrors('math.random',{{-4,-5.6,'-7','-8.9'}},'interval is empty')
checkallerrors('math.random',{somen,somem},'interval is empty')
checkallerrors('math.random',{notamn,somen},'bad argument #1')
@@ -71,8 +99,8 @@ local somee = {-3,0,3,9.10,'12.34'}
local notae = {nil,astring,aboolean,atable,afunction}
banner('math.ldexp')
checkallpass('math.ldexp',{somenumber,somee})
checkallerrors('math.ldexp',{},'bad argument #2')
checkallerrors('math.ldexp',{notanumber},'bad argument #2')
checkallerrors('math.ldexp',{},'bad argument')
checkallerrors('math.ldexp',{notanumber},'bad argument')
checkallerrors('math.ldexp',{notanumber,somee},'bad argument #1')
checkallerrors('math.ldexp',{somenumber},'bad argument #2')
checkallerrors('math.ldexp',{somenumber,notae},'bad argument #2')

View File

@@ -2,28 +2,30 @@ package org.luaj.vm;
import junit.framework.TestCase;
import org.luaj.platform.J2meMidp10Cldc10Platform;
public class MathLibTest extends TestCase {
public void testMathDPow() {
assertEquals( 1, LDouble.dpow(2, 0), 0 );
assertEquals( 2, LDouble.dpow(2, 1), 0 );
assertEquals( 8, LDouble.dpow(2, 3), 0 );
assertEquals( -8, LDouble.dpow(-2, 3), 0 );
assertEquals( 1/8., LDouble.dpow(2, -3), 0 );
assertEquals( -1/8., LDouble.dpow(-2, -3), 0 );
assertEquals( 16, LDouble.dpow(256, .5), 0 );
assertEquals( 4, LDouble.dpow(256, .25), 0 );
assertEquals( 64, LDouble.dpow(256, .75), 0 );
assertEquals( 1./16, LDouble.dpow(256, - .5), 0 );
assertEquals( 1./ 4, LDouble.dpow(256, -.25), 0 );
assertEquals( 1./64, LDouble.dpow(256, -.75), 0 );
assertEquals( Double.NaN, LDouble.dpow(-256, .5), 0 );
assertEquals( 1, LDouble.dpow(.5, 0), 0 );
assertEquals( .5, LDouble.dpow(.5, 1), 0 );
assertEquals(.125, LDouble.dpow(.5, 3), 0 );
assertEquals( 2, LDouble.dpow(.5, -1), 0 );
assertEquals( 8, LDouble.dpow(.5, -3), 0 );
assertEquals(1, LDouble.dpow(0.0625, 0), 0 );
assertEquals(0.00048828125, LDouble.dpow(0.0625, 2.75), 0 );
assertEquals( 1, J2meMidp10Cldc10Platform.dpow(2, 0), 0 );
assertEquals( 2, J2meMidp10Cldc10Platform.dpow(2, 1), 0 );
assertEquals( 8, J2meMidp10Cldc10Platform.dpow(2, 3), 0 );
assertEquals( -8, J2meMidp10Cldc10Platform.dpow(-2, 3), 0 );
assertEquals( 1/8., J2meMidp10Cldc10Platform.dpow(2, -3), 0 );
assertEquals( -1/8., J2meMidp10Cldc10Platform.dpow(-2, -3), 0 );
assertEquals( 16, J2meMidp10Cldc10Platform.dpow(256, .5), 0 );
assertEquals( 4, J2meMidp10Cldc10Platform.dpow(256, .25), 0 );
assertEquals( 64, J2meMidp10Cldc10Platform.dpow(256, .75), 0 );
assertEquals( 1./16, J2meMidp10Cldc10Platform.dpow(256, - .5), 0 );
assertEquals( 1./ 4, J2meMidp10Cldc10Platform.dpow(256, -.25), 0 );
assertEquals( 1./64, J2meMidp10Cldc10Platform.dpow(256, -.75), 0 );
assertEquals( Double.NaN, J2meMidp10Cldc10Platform.dpow(-256, .5), 0 );
assertEquals( 1, J2meMidp10Cldc10Platform.dpow(.5, 0), 0 );
assertEquals( .5, J2meMidp10Cldc10Platform.dpow(.5, 1), 0 );
assertEquals(.125, J2meMidp10Cldc10Platform.dpow(.5, 3), 0 );
assertEquals( 2, J2meMidp10Cldc10Platform.dpow(.5, -1), 0 );
assertEquals( 8, J2meMidp10Cldc10Platform.dpow(.5, -3), 0 );
assertEquals(1, J2meMidp10Cldc10Platform.dpow(0.0625, 0), 0 );
assertEquals(0.00048828125, J2meMidp10Cldc10Platform.dpow(0.0625, 2.75), 0 );
}
}