diff --git a/src/core/org/luaj/vm2/LuaValue.java b/src/core/org/luaj/vm2/LuaValue.java index 86908888..40191ff7 100644 --- a/src/core/org/luaj/vm2/LuaValue.java +++ b/src/core/org/luaj/vm2/LuaValue.java @@ -74,6 +74,9 @@ public class LuaValue extends Varargs { public static final LuaString MOD = valueOf("__mod"); public static final LuaString UNM = valueOf("__unm"); public static final LuaString LEN = valueOf("__len"); + public static final LuaString EQ = valueOf("__eq"); + public static final LuaString LT = valueOf("__lt"); + public static final LuaString LE = valueOf("__le"); public static final LuaString EMPTYSTRING = valueOf(""); private static int MAXSTACK = 250; @@ -255,19 +258,27 @@ public class LuaValue extends Varargs { public LuaValue getn() { return typerror("getn"); } // object equality, used for key comparison - public boolean equals(Object obj) { return this == obj; } + public boolean equals(Object obj) { return this == obj; } // arithmetic equality - public LuaValue eq( LuaValue val ) { return valueOf(eq_b(val)); } - public boolean eq_b( LuaValue val ) { return this == val; } - public boolean eq_b( LuaString val ) { return this == val; } - public boolean eq_b( double val ) { return false; } - public boolean eq_b( int val ) { return false; } - public LuaValue neq( LuaValue val ) { return valueOf(!eq_b(val)); } - public boolean neq_b( LuaValue val ) { return ! eq_b(val); } - public boolean neq_b( double val ) { return ! eq_b(val); } - public boolean neq_b( int val ) { return ! eq_b(val); } - + public LuaValue eq( LuaValue val ) { return eqmt(val); } + public boolean eq_b( LuaValue val ) { return this == val; } + public boolean eq_b( LuaString val ) { return false; } + public boolean eq_b( double val ) { return false; } + public boolean eq_b( int val ) { return false; } + public LuaValue neq( LuaValue val ) { return eq(val).not(); } + public boolean neq_b( LuaValue val ) { return this != val; } + public boolean neq_b( double val ) { return ! eq_b(val); } + public boolean neq_b( int val ) { return ! eq_b(val); } + public LuaValue eqmt( LuaValue op2 ) { + if ( type() != op2.type() ) + return FALSE; + if ( eq_b(op2) ) + return TRUE; + LuaValue h = metatag(EQ); + return !h.isnil() && h == op2.metatag(EQ)? h.call(this,op2): FALSE; + } + // arithmetic operators public LuaValue add( LuaValue rhs ) { return arithmt(ADD,rhs); } public LuaValue add(double rhs) { return aritherror("add"); } @@ -301,31 +312,40 @@ public class LuaValue extends Varargs { } // relational operators - public LuaValue lt( LuaValue rhs ) { return compareerror(rhs); } + public LuaValue lt( LuaValue rhs ) { return comparemt(LT,rhs); } public LuaValue lt( double rhs ) { return compareerror("number"); } public LuaValue lt( int rhs ) { return compareerror("number"); } - public boolean lt_b( LuaValue rhs ) { compareerror(rhs); return false; } + public boolean lt_b( LuaValue rhs ) { return comparemt(LT,rhs).toboolean(); } public boolean lt_b( int rhs ) { compareerror("number"); return false; } public boolean lt_b( double rhs ) { compareerror("number"); return false; } - public LuaValue lteq( LuaValue rhs ) { return compareerror(rhs); } + public LuaValue lteq( LuaValue rhs ) { return comparemt(LE,rhs); } public LuaValue lteq( double rhs ) { return compareerror("number"); } public LuaValue lteq( int rhs ) { return compareerror("number"); } - public boolean lteq_b( LuaValue rhs ) { compareerror(rhs); return false; } + public boolean lteq_b( LuaValue rhs ) { return comparemt(LE,rhs).toboolean(); } public boolean lteq_b( int rhs ) { compareerror("number"); return false; } public boolean lteq_b( double rhs ) { compareerror("number"); return false; } - public LuaValue gt( LuaValue rhs ) { return compareerror(rhs); } + public LuaValue gt( LuaValue rhs ) { return rhs.comparemt(LE,this); } public LuaValue gt( double rhs ) { return compareerror("number"); } public LuaValue gt( int rhs ) { return compareerror("number"); } - public boolean gt_b( LuaValue rhs ) { compareerror(rhs); return false; } + public boolean gt_b( LuaValue rhs ) { return rhs.comparemt(LE,this).toboolean(); } public boolean gt_b( int rhs ) { compareerror("number"); return false; } public boolean gt_b( double rhs ) { compareerror("number"); return false; } - public LuaValue gteq( LuaValue rhs ) { return compareerror("number"); } + public LuaValue gteq( LuaValue rhs ) { return rhs.comparemt(LT,this); } public LuaValue gteq( double rhs ) { return compareerror("number"); } public LuaValue gteq( int rhs ) { return valueOf(todouble() >= rhs); } - public boolean gteq_b( LuaValue rhs ) { compareerror(rhs); return false; } + public boolean gteq_b( LuaValue rhs ) { return rhs.comparemt(LT,this).toboolean(); } public boolean gteq_b( int rhs ) { compareerror("number"); return false; } public boolean gteq_b( double rhs ) { compareerror("number"); return false; } - + public LuaValue comparemt( LuaValue tag, LuaValue op1 ) { + if ( type() == op1.type() ) { + LuaValue h = metatag(tag); + if ( !h.isnil() && h == op1.metatag(tag) ) + return h.call(this, op1); + } + return error("attempt to compare "+typename()); + } + + // string comparison public int strcmp( LuaValue rhs ) { error("attempt to compare "+typename()); return 0; } public int strcmp( LuaString rhs ) { error("attempt to compare "+typename()); return 0; } diff --git a/test/lua/metatags.lua b/test/lua/metatags.lua index ac151612..f6fac66f 100644 --- a/test/lua/metatags.lua +++ b/test/lua/metatags.lua @@ -8,9 +8,15 @@ for i=1,#values do print( debug.getmetatable( values[i] ) ) end local ts = tostring +local tb,count = {},0 tostring = function(o) local t = type(o) - return (t=='thread' or t=='function') and t or ts(o) + if t~='thread' and t~='function' then return ts(o) end + if not tb[o] then + count = count + 1 + tb[o] = t..'.'..count + end + return tb[o] end local buildunop = function(name) @@ -40,6 +46,9 @@ local mt = { __mod=buildbinop('mod'), __unm=buildunop('unm'), __len=buildunop('neg'), + __eq=buildbinop('eq'), + __lt=buildbinop('lt'), + __le=buildbinop('le'), } -- pcall a function and check for a pattern in the error string @@ -107,4 +116,48 @@ for i=1,#values do print( debug.setmetatable( values[i], nil ) ) end - \ No newline at end of file +print( '---- __eq, __lt, __le, same types' ) +local bfunction = function() end +local bthread = coroutine.create( bfunction ) +local groups +groups = { {afunction, bfunction}, {true, true}, {true, false}, {afunction, bfunction}, {athread, bthread}, } +for i=1,#groups do + local a,b = groups[i][1], groups[i][2] + print( type(values[i]), 'before', pcall( function() return a==b end ) ) + print( type(values[i]), 'before', pcall( function() return a~=b end ) ) + print( type(values[i]), 'before', ecall( 'attempt to compare', function() return ab end ) ) + print( type(values[i]), 'before', ecall( 'attempt to compare', function() return a>=b end ) ) + print( debug.setmetatable( a, mt ) ) + print( debug.setmetatable( b, mt ) ) + print( type(values[i]), 'after', pcall( function() return a==b end ) ) + print( type(values[i]), 'after', pcall( function() return a~=b end ) ) + print( type(values[i]), 'after', pcall( function() return ab end ) ) + print( type(values[i]), 'after', pcall( function() return a>=b end ) ) + print( debug.setmetatable( a, nil ) ) + print( debug.setmetatable( b, nil ) ) +end +print( '---- __eq, __lt, __le, different types' ) +groups = { {aboolean, athread}, } +for i=1,#groups do + local a,b = groups[i][1], groups[i][2] + print( type(values[i]), 'before', pcall( function() return a==b end ) ) + print( type(values[i]), 'before', pcall( function() return a~=b end ) ) + print( type(values[i]), 'before', ecall( 'attempt to compare', function() return ab end ) ) + print( type(values[i]), 'before', ecall( 'attempt to compare', function() return a>=b end ) ) + print( debug.setmetatable( a, mt ) ) + print( debug.setmetatable( b, mt ) ) + print( type(values[i]), 'after-a', pcall( function() return a==b end ) ) + print( type(values[i]), 'after-a', pcall( function() return a~=b end ) ) + print( type(values[i]), 'after-a', ecall( 'attempt to compare', function() return ab end ) ) + print( type(values[i]), 'after-a', ecall( 'attempt to compare', function() return a>=b end ) ) + print( debug.setmetatable( a, nil ) ) + print( debug.setmetatable( b, nil ) ) +end