diff --git a/src/core/org/luaj/vm2/LuaBoolean.java b/src/core/org/luaj/vm2/LuaBoolean.java index 31edbc3a..c8d3881f 100644 --- a/src/core/org/luaj/vm2/LuaBoolean.java +++ b/src/core/org/luaj/vm2/LuaBoolean.java @@ -69,5 +69,5 @@ public class LuaBoolean extends LuaValue { public LuaValue getmetatable() { return s_metatable; - } + } } diff --git a/src/core/org/luaj/vm2/LuaDouble.java b/src/core/org/luaj/vm2/LuaDouble.java index 83e81101..5fd1a903 100644 --- a/src/core/org/luaj/vm2/LuaDouble.java +++ b/src/core/org/luaj/vm2/LuaDouble.java @@ -75,7 +75,8 @@ public class LuaDouble extends LuaNumber { public boolean equals(Object o) { return o instanceof LuaDouble? ((LuaDouble)o).v == v: false; } // arithmetic equality - public boolean eq_b( LuaValue rhs ) { return rhs.eq_b(v); } + public LuaValue eq( LuaValue rhs ) { return rhs.eq_b(v)? TRUE: FALSE; } + public boolean eq_b( LuaValue rhs ) { return rhs.eq_b(v); } public boolean eq_b( double rhs ) { return v == rhs; } public boolean eq_b( int rhs ) { return v == rhs; } diff --git a/src/core/org/luaj/vm2/LuaInteger.java b/src/core/org/luaj/vm2/LuaInteger.java index 27ad103e..9663ba85 100644 --- a/src/core/org/luaj/vm2/LuaInteger.java +++ b/src/core/org/luaj/vm2/LuaInteger.java @@ -105,7 +105,8 @@ public class LuaInteger extends LuaNumber { public boolean equals(Object o) { return o instanceof LuaInteger? ((LuaInteger)o).v == v: false; } // arithmetic equality - public boolean eq_b( LuaValue rhs ) { return rhs.eq_b(v); } + public LuaValue eq( LuaValue rhs ) { return rhs.eq_b(v)? TRUE: FALSE; } + public boolean eq_b( LuaValue rhs ) { return rhs.eq_b(v); } public boolean eq_b( double rhs ) { return v == rhs; } public boolean eq_b( int rhs ) { return v == rhs; } diff --git a/src/core/org/luaj/vm2/LuaTable.java b/src/core/org/luaj/vm2/LuaTable.java index ad81eae1..154bf7be 100644 --- a/src/core/org/luaj/vm2/LuaTable.java +++ b/src/core/org/luaj/vm2/LuaTable.java @@ -286,6 +286,18 @@ public class LuaTable extends LuaValue { public LuaValue len() { return LuaInteger.valueOf(length()); } + + public LuaValue eq( LuaValue rhs ) { + return rhs.eq_b(this)? TRUE: FALSE; + } + + public boolean eq_b( LuaValue rhs ) { + return rhs.eq_b(this); + } + + public boolean eq_b( LuaTable val ) { + return this == val || val.eqmt_b(this); + } public int maxn() { int n = 0; diff --git a/src/core/org/luaj/vm2/LuaUserdata.java b/src/core/org/luaj/vm2/LuaUserdata.java index 82935246..8df3245b 100644 --- a/src/core/org/luaj/vm2/LuaUserdata.java +++ b/src/core/org/luaj/vm2/LuaUserdata.java @@ -96,14 +96,23 @@ public class LuaUserdata extends LuaValue { } public boolean equals( Object val ) { + if ( this == val ) + return true; if ( ! (val instanceof LuaUserdata) ) return false; LuaUserdata u = (LuaUserdata) val; return m_instance.equals(u.m_instance); } - public boolean eq_b( LuaValue val ) { - return equals(val); + public LuaValue eq( LuaValue rhs ) { + return rhs.eq_b(this)? TRUE: FALSE; + } + + public boolean eq_b( LuaValue rhs ) { + return rhs.eq_b(this); + } + + public boolean eq_b( LuaUserdata val ) { + return this == val || m_instance.equals(val.m_instance) || val.eqmt_b(this); } - } diff --git a/src/core/org/luaj/vm2/LuaValue.java b/src/core/org/luaj/vm2/LuaValue.java index 77ba0442..987a8a71 100644 --- a/src/core/org/luaj/vm2/LuaValue.java +++ b/src/core/org/luaj/vm2/LuaValue.java @@ -262,22 +262,20 @@ public class LuaValue extends Varargs { public boolean equals(Object obj) { return this == obj; } // arithmetic equality - public LuaValue eq( LuaValue val ) { return eqmt(val); } + public LuaValue eq( LuaValue val ) { return eq_b(val)? TRUE: FALSE; } public boolean eq_b( LuaValue val ) { return this == val; } + public boolean eq_b( LuaTable val ) { return false; } + public boolean eq_b( LuaUserdata val ) { return false; } public boolean eq_b( LuaString val ) { return false; } public boolean eq_b( double val ) { return false; } public boolean eq_b( int val ) { return false; } - public LuaValue neq( LuaValue val ) { return eq(val).not(); } - public boolean neq_b( LuaValue val ) { return this != val; } + public LuaValue neq( LuaValue val ) { return eq_b(val)? FALSE: TRUE; } + public boolean neq_b( LuaValue val ) { return ! eq_b(val); } public boolean neq_b( double val ) { return ! eq_b(val); } public boolean neq_b( int val ) { return ! eq_b(val); } - public LuaValue eqmt( LuaValue op2 ) { - if ( type() != op2.type() ) - return FALSE; - if ( eq_b(op2) ) - return TRUE; + protected boolean eqmt_b( LuaValue op2 ) { LuaValue h = metatag(EQ); - return !h.isnil() && h == op2.metatag(EQ)? h.call(this,op2): FALSE; + return !h.isnil() && h==op2.metatag(EQ)? h.call(this,op2).toboolean(): false; } // arithmetic operators diff --git a/test/junit/org/luaj/vm2/UnaryBinaryOperatorsTest.java b/test/junit/org/luaj/vm2/UnaryBinaryOperatorsTest.java index 3195226a..9cf0308d 100644 --- a/test/junit/org/luaj/vm2/UnaryBinaryOperatorsTest.java +++ b/test/junit/org/luaj/vm2/UnaryBinaryOperatorsTest.java @@ -176,6 +176,15 @@ public class UnaryBinaryOperatorsTest extends TestCase { // check arithmetic equality among different types assertEquals(ia.eq(sa),LuaValue.FALSE); assertEquals(sa.eq(ia),LuaValue.FALSE); + + // equals with mismatched types + LuaValue t = new LuaTable(); + assertEquals(ia.eq(t),LuaValue.FALSE); + assertEquals(t.eq(ia),LuaValue.FALSE); + assertEquals(ia.eq(LuaValue.FALSE),LuaValue.FALSE); + assertEquals(LuaValue.FALSE.eq(ia),LuaValue.FALSE); + assertEquals(ia.eq(LuaValue.NIL),LuaValue.FALSE); + assertEquals(LuaValue.NIL.eq(ia),LuaValue.FALSE); } public void testEqDouble() { @@ -191,6 +200,15 @@ public class UnaryBinaryOperatorsTest extends TestCase { // check arithmetic equality among different types assertEquals(da.eq(sa),LuaValue.FALSE); assertEquals(sa.eq(da),LuaValue.FALSE); + + // equals with mismatched types + LuaValue t = new LuaTable(); + assertEquals(da.eq(t),LuaValue.FALSE); + assertEquals(t.eq(da),LuaValue.FALSE); + assertEquals(da.eq(LuaValue.FALSE),LuaValue.FALSE); + assertEquals(LuaValue.FALSE.eq(da),LuaValue.FALSE); + assertEquals(da.eq(LuaValue.NIL),LuaValue.FALSE); + assertEquals(LuaValue.NIL.eq(da),LuaValue.FALSE); } public void testAdd() { diff --git a/test/lua/metatags.lua b/test/lua/metatags.lua index 2cf2a677..65102437 100644 --- a/test/lua/metatags.lua +++ b/test/lua/metatags.lua @@ -3,7 +3,8 @@ local anumber = 111 local aboolean = false local afunction = function() end local athread = coroutine.create( afunction ) -local values = { anumber, aboolean, afunction, athread } +local atable = {} +local values = { anumber, aboolean, afunction, athread, atable } for i=1,#values do print( debug.getmetatable( values[i] ) ) end @@ -68,7 +69,7 @@ for i=1,#values do end print( '---- __add, __sub, __mul, __div, __pow, __mod' ) -local groups = { {aboolean, aboolean}, {aboolean, athread}, {aboolean, afunction}, {aboolean, "abc"} } +local groups = { {aboolean, aboolean}, {aboolean, athread}, {aboolean, afunction}, {aboolean, "abc"}, {aboolean, atable} } for i=1,#groups do local a,b = groups[i][1], groups[i][2] print( type(a), type(b), 'before', ecall( 'attempt to perform arithmetic', function() return a+b end ) ) @@ -106,7 +107,7 @@ for i=1,#values do end print( '---- __neg' ) -values = { aboolean, afunction, athread, "abcd" } +values = { aboolean, afunction, athread, "abcd", atable } for i=1,#values do print( type(values[i]), 'before', ecall( 'attempt to get length of ', function() return #values[i] end ) ) print( debug.setmetatable( values[i], mt ) ) @@ -118,7 +119,7 @@ print( '---- __eq, __lt, __le, same types' ) local bfunction = function() end local bthread = coroutine.create( bfunction ) local groups -groups = { {afunction, bfunction}, {true, true}, {true, false}, {afunction, bfunction}, {athread, bthread}, } +groups = { {afunction, bfunction}, {true, true}, {true, false}, {afunction, bfunction}, {athread, bthread}, {atable, atable}, {atable, {}} } for i=1,#groups do local a,b = groups[i][1], groups[i][2] print( type(a), type(b), 'before', pcall( function() return a==b end ) ) @@ -162,7 +163,7 @@ for i=1,#groups do end print( '---- __tostring' ) -values = { aboolean, afunction, athread } +values = { aboolean, afunction, athread, atable, "abc" } for i=1,#values do local a = values[i] print( debug.setmetatable( a, mt ) ) @@ -171,7 +172,7 @@ for i=1,#values do end print( '---- __metatable' ) -values = { aboolean, afunction, athread } +values = { aboolean, afunction, athread, atable, "abc" } for i=1,#values do local a = values[i] print( type(a), 'before', pcall( function() return debug.getmetatable(a), getmetatable(a) end ) )