Add arithmetic metatag processing.

This commit is contained in:
James Roseborough
2010-08-18 18:55:12 +00:00
parent fdeb392205
commit 7958ee7109
3 changed files with 150 additions and 58 deletions

View File

@@ -28,6 +28,7 @@ import java.io.InputStream;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
import java.util.Hashtable; import java.util.Hashtable;
import org.luaj.vm2.lib.MathLib;
import org.luaj.vm2.lib.StringLib; import org.luaj.vm2.lib.StringLib;
public class LuaString extends LuaValue { public class LuaString extends LuaValue {
@@ -108,23 +109,32 @@ public class LuaString extends LuaValue {
} }
// unary operators // unary operators
public LuaValue neg() { return checkarith().neg(); } public LuaValue neg() { return valueOf(-checkarith()); }
// basic binary arithmetic // basic binary arithmetic
public LuaValue add( LuaValue rhs ) { return checkarith().add(rhs); } public LuaValue add( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(ADD,rhs): rhs.add(d); }
public LuaValue add( double lhs ) { return checkarith().add(lhs); } public LuaValue add( double rhs ) { return valueOf( checkarith() + rhs ); }
public LuaValue sub( LuaValue rhs ) { return checkarith().sub(rhs); } public LuaValue add( int rhs ) { return valueOf( checkarith() + rhs ); }
public LuaValue subFrom( double lhs ) { return checkarith().subFrom(lhs); } public LuaValue sub( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(SUB,rhs): rhs.subFrom(d); }
public LuaValue mul( LuaValue rhs ) { return checkarith().mul(rhs); } public LuaValue sub( double rhs ) { return valueOf( checkarith() - rhs ); }
public LuaValue mul( double lhs ) { return checkarith().mul(lhs); } public LuaValue sub( int rhs ) { return valueOf( checkarith() - rhs ); }
public LuaValue mul( int lhs ) { return checkarith().mul(lhs); } public LuaValue subFrom( double lhs ) { return valueOf( lhs - checkarith() ); }
public LuaValue pow( LuaValue rhs ) { return checkarith().pow(rhs); } public LuaValue mul( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(MUL,rhs): rhs.mul(d); }
public LuaValue powWith( double lhs ) { return checkarith().powWith(lhs); } public LuaValue mul( double rhs ) { return valueOf( checkarith() * rhs ); }
public LuaValue powWith( int lhs ) { return checkarith().powWith(lhs); } public LuaValue mul( int rhs ) { return valueOf( checkarith() * rhs ); }
public LuaValue div( LuaValue rhs ) { return checkarith().div(rhs); } public LuaValue pow( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(POW,rhs): rhs.powWith(d); }
public LuaValue divInto( double lhs ) { return checkarith().divInto(lhs); } public LuaValue pow( double rhs ) { return MathLib.dpow(checkarith(),rhs); }
public LuaValue mod( LuaValue rhs ) { return checkarith().mod(rhs); } public LuaValue pow( int rhs ) { return MathLib.dpow(checkarith(),rhs); }
public LuaValue modFrom( double lhs ) { return checkarith().modFrom(lhs); } public LuaValue powWith( double lhs ) { return MathLib.dpow(lhs, checkarith()); }
public LuaValue powWith( int lhs ) { return MathLib.dpow(lhs, checkarith()); }
public LuaValue div( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(DIV,rhs): rhs.divInto(d); }
public LuaValue div( double rhs ) { return LuaDouble.ddiv(checkarith(),rhs); }
public LuaValue div( int rhs ) { return LuaDouble.ddiv(checkarith(),rhs); }
public LuaValue divInto( double lhs ) { return LuaDouble.ddiv(lhs, checkarith()); }
public LuaValue mod( LuaValue rhs ) { double d = scannumber(10); return Double.isNaN(d)? arithmt(MOD,rhs): rhs.modFrom(d); }
public LuaValue mod( double rhs ) { return LuaDouble.dmod(checkarith(), rhs); }
public LuaValue mod( int rhs ) { return LuaDouble.dmod(checkarith(), rhs); }
public LuaValue modFrom( double lhs ) { return LuaDouble.dmod(lhs, checkarith()); }
// relational operators, these only work with other strings // relational operators, these only work with other strings
public LuaValue lt( LuaValue rhs ) { return rhs.strcmp(this)>0? LuaValue.TRUE: FALSE; } public LuaValue lt( LuaValue rhs ) { return rhs.strcmp(this)>0? LuaValue.TRUE: FALSE; }
@@ -160,56 +170,67 @@ public class LuaString extends LuaValue {
} }
/** Check for number in arithmetic, or throw aritherror */ /** Check for number in arithmetic, or throw aritherror */
private LuaValue checkarith() { private double checkarith() {
LuaValue v = tonumber(10); double d = scannumber(10);
return v.isnil()? aritherror(): v; if ( Double.isNaN(d) )
aritherror();
return d;
} }
public int checkint() { public int checkint() {
return checknumber().toint(); return (int) (long) checkdouble();
} }
public LuaInteger checkinteger() { public LuaInteger checkinteger() {
return checknumber().checkinteger(); return valueOf(checkint());
} }
public long checklong() { public long checklong() {
return checknumber().tolong(); return (long) checkdouble();
} }
public double checkdouble() { public double checkdouble() {
return checknumber().todouble(); double d = scannumber(10);
if ( Double.isNaN(d) )
argerror("number");
return d;
} }
public LuaNumber checknumber() { public LuaNumber checknumber() {
LuaValue n = tonumber(10); return valueOf(checkdouble());
if ( ! n.isnumber() )
argerror("number");
return (LuaNumber) n;
} }
public LuaNumber checknumber(String msg) { public LuaNumber checknumber(String msg) {
LuaValue n = tonumber(10); double d = scannumber(10);
if ( ! n.isnumber() ) if ( Double.isNaN(d) )
argerror(msg); argerror("number");
return (LuaNumber) n; return valueOf(d);
} }
public LuaValue tonumber() { public LuaValue tonumber() {
return tonumber(10); return tonumber(10);
} }
public boolean isnumber() { public boolean isnumber() {
return ! tonumber(10).isnil(); double d = scannumber(10);
return ! Double.isNaN(d);
} }
public boolean isint() { public boolean isint() {
return tonumber(10).isint(); double d = scannumber(10);
if ( Double.isNaN(d) )
return false;
int i = (int) d;
return i == d;
} }
public boolean islong() { public boolean islong() {
return tonumber(10).islong(); double d = scannumber(10);
if ( Double.isNaN(d) )
return false;
long l = (long) d;
return l == d;
} }
public byte tobyte() { return (byte) toint(); } public byte tobyte() { return (byte) toint(); }
public char tochar() { return (char) toint(); } public char tochar() { return (char) toint(); }
public double todouble() { LuaValue n=tonumber(10); return n.isnil()? 0: n.todouble(); } public double todouble() { double d=scannumber(10); return Double.isNaN(d)? 0: d; }
public float tofloat() { return (float) todouble(); } public float tofloat() { return (float) todouble(); }
public int toint() { LuaValue n=tonumber(10); return n.isnil()? 0: n.toint(); } public int toint() { return (int) tolong(); }
public long tolong() { return (long) todouble(); } public long tolong() { return (long) todouble(); }
public short toshort() { return (short) toint(); } public short toshort() { return (short) toint(); }
@@ -467,44 +488,53 @@ public class LuaString extends LuaValue {
* @return IntValue, DoubleValue, or NIL depending on the content of the string. * @return IntValue, DoubleValue, or NIL depending on the content of the string.
*/ */
public LuaValue tonumber( int base ) { public LuaValue tonumber( int base ) {
double d = scannumber( base );
return Double.isNaN(d)? NIL: valueOf(d);
}
/**
* Convert to a number in a base, or return Double.NaN if not a number.
*/
public double scannumber( int base ) {
if ( base >= 2 && base <= 36 ) { if ( base >= 2 && base <= 36 ) {
int i=m_offset,j=m_offset+m_length; int i=m_offset,j=m_offset+m_length;
while ( i<j && m_bytes[i]==' ' ) ++i; while ( i<j && m_bytes[i]==' ' ) ++i;
while ( i<j && m_bytes[j-1]==' ' ) --j; while ( i<j && m_bytes[j-1]==' ' ) --j;
if ( i>=j ) return FALSE; if ( i>=j )
return Double.NaN;
if ( ( base == 10 || base == 16 ) && ( m_bytes[i]=='0' && i+1<j && (m_bytes[i+1]=='x'||m_bytes[i+1]=='X') ) ) { if ( ( base == 10 || base == 16 ) && ( m_bytes[i]=='0' && i+1<j && (m_bytes[i+1]=='x'||m_bytes[i+1]=='X') ) ) {
base = 16; base = 16;
i+=2; i+=2;
} }
LuaValue l = scanlong( base, i, j ); double l = scanlong( base, i, j );
return l!=NIL? l: base==10? scandouble(i,j): NIL; return Double.isNaN(l) && base==10? scandouble(i,j): l;
} }
return NIL; return Double.NaN;
} }
/** /**
* Scan and convert a long value, or return NIL if not found. * Scan and convert a long value, or return Double.NaN if not found.
* @return DoubleValue, IntValue, or NIL depending on what is found. * @return DoubleValue, IntValue, or Double.NaN depending on what is found.
*/ */
private LuaValue scanlong( int base, int start, int end ) { private double scanlong( int base, int start, int end ) {
long x = 0; long x = 0;
boolean neg = (m_bytes[start] == '-'); boolean neg = (m_bytes[start] == '-');
for ( int i=(neg?start+1:start); i<end; i++ ) { for ( int i=(neg?start+1:start); i<end; i++ ) {
int digit = m_bytes[i] - (base<=10||(m_bytes[i]>='0'&&m_bytes[i]<='9')? '0': int digit = m_bytes[i] - (base<=10||(m_bytes[i]>='0'&&m_bytes[i]<='9')? '0':
m_bytes[i]>='A'&&m_bytes[i]<='Z'? ('A'-10): ('a'-10)); m_bytes[i]>='A'&&m_bytes[i]<='Z'? ('A'-10): ('a'-10));
if ( digit < 0 || digit >= base ) if ( digit < 0 || digit >= base )
return NIL; return Double.NaN;
x = x * base + digit; x = x * base + digit;
} }
return valueOf(neg? -x: x); return neg? -x: x;
} }
/** /**
* Scan and convert a double value, or return NIL if not a double. * Scan and convert a double value, or return Double.NaN if not a double.
* @return DoubleValue, IntValue, or NIL depending on what is found. * @return DoubleValue, IntValue, or Double.NaN depending on what is found.
*/ */
private LuaValue scandouble(int start, int end) { private double scandouble(int start, int end) {
if ( end>start+64 ) end=start+64; if ( end>start+64 ) end=start+64;
for ( int i=start; i<end; i++ ) { for ( int i=start; i<end; i++ ) {
switch ( m_bytes[i] ) { switch ( m_bytes[i] ) {
@@ -516,17 +546,17 @@ public class LuaString extends LuaValue {
case '5': case '6': case '7': case '8': case '9': case '5': case '6': case '7': case '8': case '9':
break; break;
default: default:
return NIL; return Double.NaN;
} }
} }
char [] c = new char[end-start]; char [] c = new char[end-start];
for ( int i=start; i<end; i++ ) for ( int i=start; i<end; i++ )
c[i-start] = (char) m_bytes[i]; c[i-start] = (char) m_bytes[i];
try { try {
return valueOf( Double.parseDouble(new String(c))); return Double.parseDouble(new String(c));
} catch ( Exception e ) { } catch ( Exception e ) {
return Double.NaN;
} }
return NIL;
} }
} }

View File

@@ -66,6 +66,12 @@ public class LuaValue extends Varargs {
public static final LuaString CALL = valueOf("__call"); public static final LuaString CALL = valueOf("__call");
public static final LuaString MODE = valueOf("__mode"); public static final LuaString MODE = valueOf("__mode");
public static final LuaString METATABLE = valueOf("__metatable"); public static final LuaString METATABLE = valueOf("__metatable");
public static final LuaString ADD = valueOf("__add");
public static final LuaString SUB = valueOf("__sub");
public static final LuaString DIV = valueOf("__div");
public static final LuaString MUL = valueOf("__mul");
public static final LuaString POW = valueOf("__pow");
public static final LuaString MOD = valueOf("__mod");
public static final LuaString EMPTYSTRING = valueOf(""); public static final LuaString EMPTYSTRING = valueOf("");
private static int MAXSTACK = 250; private static int MAXSTACK = 250;
@@ -261,30 +267,36 @@ public class LuaValue extends Varargs {
public boolean neq_b( int val ) { return ! eq_b(val); } public boolean neq_b( int val ) { return ! eq_b(val); }
// arithmetic operators // arithmetic operators
public LuaValue add( LuaValue rhs ) { return aritherror("add"); } public LuaValue add( LuaValue rhs ) { return arithmt(ADD,rhs); }
public LuaValue add(double rhs) { return aritherror("add"); } public LuaValue add(double rhs) { return aritherror("add"); }
public LuaValue add(int rhs) { return add((double)rhs); } public LuaValue add(int rhs) { return add((double)rhs); }
public LuaValue sub( LuaValue rhs ) { return aritherror("sub"); } public LuaValue sub( LuaValue rhs ) { return arithmt(SUB,rhs); }
public LuaValue sub( double rhs ) { return aritherror("sub"); } public LuaValue sub( double rhs ) { return aritherror("sub"); }
public LuaValue sub( int rhs ) { return aritherror("sub"); } public LuaValue sub( int rhs ) { return aritherror("sub"); }
public LuaValue subFrom(double lhs) { return aritherror("sub"); } public LuaValue subFrom(double lhs) { return aritherror("sub"); }
public LuaValue subFrom(int lhs) { return subFrom((double)lhs); } public LuaValue subFrom(int lhs) { return subFrom((double)lhs); }
public LuaValue mul( LuaValue rhs ) { return aritherror("mul"); } public LuaValue mul( LuaValue rhs ) { return arithmt(MUL,rhs); }
public LuaValue mul(double rhs) { return aritherror("mul"); } public LuaValue mul(double rhs) { return aritherror("mul"); }
public LuaValue mul(int rhs) { return mul((double)rhs); } public LuaValue mul(int rhs) { return mul((double)rhs); }
public LuaValue pow( LuaValue rhs ) { return aritherror("pow"); } public LuaValue pow( LuaValue rhs ) { return arithmt(POW,rhs); }
public LuaValue pow( double rhs ) { return aritherror("pow"); } public LuaValue pow( double rhs ) { return aritherror("pow"); }
public LuaValue pow( int rhs ) { return aritherror("pow"); } public LuaValue pow( int rhs ) { return aritherror("pow"); }
public LuaValue powWith(double lhs) { return aritherror("mul"); } public LuaValue powWith(double lhs) { return aritherror("mul"); }
public LuaValue powWith(int lhs) { return powWith((double)lhs); } public LuaValue powWith(int lhs) { return powWith((double)lhs); }
public LuaValue div( LuaValue rhs ) { return aritherror("div"); } public LuaValue div( LuaValue rhs ) { return arithmt(DIV,rhs); }
public LuaValue div( double rhs ) { return aritherror("div"); } public LuaValue div( double rhs ) { return aritherror("div"); }
public LuaValue div( int rhs ) { return aritherror("div"); } public LuaValue div( int rhs ) { return aritherror("div"); }
public LuaValue divInto(double lhs) { return aritherror("divInto"); } public LuaValue divInto(double lhs) { return aritherror("divInto"); }
public LuaValue mod( LuaValue rhs ) { return aritherror("mod"); } public LuaValue mod( LuaValue rhs ) { return arithmt(MOD,rhs); }
public LuaValue mod( double rhs ) { return aritherror("mod"); } public LuaValue mod( double rhs ) { return aritherror("mod"); }
public LuaValue mod( int rhs ) { return aritherror("mod"); } public LuaValue mod( int rhs ) { return aritherror("mod"); }
public LuaValue modFrom(double lhs) { return aritherror("modFrom"); } public LuaValue modFrom(double lhs) { return aritherror("modFrom"); }
protected LuaValue arithmt(LuaValue tag, LuaValue op2) {
LuaValue h = this.metatag(tag);
if ( h.isnil() )
h = op2.checkmetatag(tag, "attempt to perform arithmetic on ");
return h.call( this, op2 );
}
// relational operators // relational operators
public LuaValue lt( LuaValue rhs ) { return compareerror(rhs); } public LuaValue lt( LuaValue rhs ) { return compareerror(rhs); }

View File

@@ -1,14 +1,36 @@
print( '---- initial metatables' ) print( '---- initial metatables' )
local values = { 1, false, coroutine.create( function() end ) } local anumber = 111
local aboolean = false
local afunction = function() end
local athread = coroutine.create( afunction )
local values = { athread, aboolean, afunction, athread }
for i=1,#values do for i=1,#values do
print( debug.getmetatable( values[i] ) ) print( debug.getmetatable( values[i] ) )
end end
local ts = tostring
tostring = function(o)
local t = type(o)
return (t=='thread' or t=='function') and t or ts(o)
end
local buildbin = function(name)
return function(a,b)
print( 'mt.__'..name..'()', type(a), type(b), a, b )
return '__'..name..'-result'
end
end
local mt = { local mt = {
__call=function(a,b,c) __call=function(a,b,c)
print( 'mt.__call()', type(a), type(b), type(c), b, c ) print( 'mt.__call()', type(a), type(b), type(c), b, c )
return '__call-result' return '__call-result'
end, end,
__add=buildbin('add'),
__sub=buildbin('sub'),
__mul=buildbin('mul'),
__div=buildbin('div'),
__pow=buildbin('pow'),
__mod=buildbin('mod'),
} }
-- pcall a function and check for a pattern in the error string -- pcall a function and check for a pattern in the error string
@@ -30,6 +52,34 @@ for i=1,#values do
print( debug.setmetatable( values[i], nil ) ) print( debug.setmetatable( values[i], nil ) )
end end
print( '---- __add, __sub, __mul, __div, __pow, __mod' )
local groups = { {aboolean, aboolean}, {aboolean, athread}, {aboolean, afunction}, {aboolean, "abc"} }
for i=1,#groups do
local a,b = groups[i][1], groups[i][2]
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return a+b end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return b+a end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return a-b end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return b-a end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return a*b end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return b*a end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return a^b end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return b^a end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return a%b end ) )
print( type(values[i]), 'before', ecall( 'attempt to perform arithmetic', function() return b%a end ) )
print( debug.setmetatable( a, mt ) )
print( type(values[i]), 'after', pcall( function() return a+b end ) )
print( type(values[i]), 'after', pcall( function() return b+a end ) )
print( type(values[i]), 'after', pcall( function() return a-b end ) )
print( type(values[i]), 'after', pcall( function() return b-a end ) )
print( type(values[i]), 'after', pcall( function() return a*b end ) )
print( type(values[i]), 'after', pcall( function() return b*a end ) )
print( type(values[i]), 'after', pcall( function() return a^b end ) )
print( type(values[i]), 'after', pcall( function() return b^a end ) )
print( type(values[i]), 'after', pcall( function() return a%b end ) )
print( type(values[i]), 'after', pcall( function() return b%a end ) )
print( debug.setmetatable( a, nil ) )
end
print( '---- final metatables' ) print( '---- final metatables' )