// set expression class -*- c++ -*-

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "SetExpression.h"
#include "Net.h"
#include "Constant.h"
#include "EmptySet.h"
#include "BoolType.h"
#include "LeafValue.h"
#include "PlaceMarking.h"
#include "Printer.h"

/** @file SetExpression.C
 * Basic operations on multi-sets
 */

/* Copyright © 1998-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Constructor; takes over both operands (they are destroyed by ~SetExpression)
 * @param op	the set operator
 * @param left	left-hand-side operand
 * @param right	right-hand-side operand
 */
SetExpression::SetExpression (enum Op op,
			      class Expression& left,
			      class Expression& right) :
  Expression (),
  myOp (op), myLeft (&left), myRight (&right)
{
  assert (myLeft->getType () || myRight->getType ());
  assert (!myLeft->getType () || !myRight->getType () ||
	  myLeft->getType () == myRight->getType ());
  assert (!myLeft->isTemporal () && !myRight->isTemporal ());
  // the following assertions should be guaranteed by construct ()
  // (the right-hand check used to read "!myRight->getKind () != eEmptySet",
  // which compared a bool with the enumeration and therefore never fired)
  // NOTE(review): ground () and substitute () create SetExpression objects
  // directly instead of going through construct (); confirm they can never
  // produce an empty-set right-hand operand.
  assert (myRight->getKind () != eEmptySet);
  assert (myLeft->getKind () != eEmptySet ||
	  myOp == sEquals);
  switch (myOp) {
  case sSubset:
  case sEquals:
    setType (Net::getBoolType ());
    break;
  case sIntersection:
  case sMinus:
  case sUnion:
    // the multi-set result has the operand type; either operand may be
    // the typed one (only one of them is guaranteed to have a type)
    setType (myLeft->getType ()
	     ? *myLeft->getType ()
	     : *myRight->getType ());
    break;
  }
}

/** Destructor; releases both operands, the left one first */
SetExpression::~SetExpression ()
{
  class Expression* const operands[] = { myLeft, myRight };
  for (unsigned i = 0; i < sizeof operands / sizeof *operands; i++)
    operands[i]->destroy ();
}

class Expression*
SetExpression::construct (enum Op op,
			  class Expression& left,
			  class Expression& right)
{
  if (left == right) {
    switch (op) {
    case sSubset:
    case sEquals:
    True:
      right.destroy ();
      left.destroy ();
      return (new class Constant (*new class LeafValue
				  (Net::getBoolType (), true)))->cse ();
    case sMinus:
      right.destroy ();
      left.destroy ();
      return (new class EmptySet ())->cse ();
      break;
    case sIntersection:
    case sUnion:
      right.destroy ();
      return &left;
    }
  }
  else if (left.getKind () == eEmptySet) {
    switch (op) {
    case sSubset:
      goto True;
    case sEquals:
      return (new class SetExpression (op, left, right))->cse ();
    case sMinus:
    case sIntersection:
      right.destroy ();
      return &left;
    case sUnion:
      left.destroy ();
      return &right;
    }
  }
  else if (right.getKind () == eEmptySet) {
    switch (op) {
    case sSubset:
    case sEquals:
      return (new class SetExpression (sEquals, right, left))->cse ();
    case sMinus:
      right.destroy ();
      return &left;
    case sIntersection:
    case sUnion:
      left.destroy ();
      return &right;
    }
  }
  else {
    switch (op) {
    case sEquals:
    case sIntersection:
    case sUnion:
      if (right < left) // transform the expression to canonical form
	return (new class SetExpression (op, right, left))->cse ();
      // fall through
    case sMinus:
    case sSubset:
      return (new class SetExpression (op, left, right))->cse ();
    }
  }

  assert (false);
  return NULL;
}

/** Evaluate a boolean-valued set expression (subset or equality)
 * @param valuation	variable substitutions and the evaluation error status
 * @return	the boolean result (after constrain ()), or NULL on error
 */
class Value*
SetExpression::do_eval (const class Valuation& valuation) const
{
  class PlaceMarking* l = myLeft->meval (valuation);
  if (!l)
    return NULL;
  class PlaceMarking* r = myRight->meval (valuation);
  if (!r) {
    delete l;
    return NULL;
  }

  bool result = false;

  switch (myOp) {
  case sSubset:
    result = *l <= *r;
    break;
  case sEquals:
    result = *l == *r;
    break;
  default:
    // only the boolean-valued operators are evaluated here; release
    // both operand markings so that builds with NDEBUG do not leak
    assert (false);
    delete l;
    delete r;
    return NULL;
  }

  delete l;
  delete r;
  return constrain (valuation, new class LeafValue (*getType (), result));
}

/** Evaluate a multi-set-valued set expression (intersection, minus, union)
 * @param valuation	variable substitutions and the evaluation error status
 * @return	a newly allocated marking with no place, or NULL on error
 */
class PlaceMarking*
SetExpression::meval (const class Valuation& valuation) const
{
  class PlaceMarking* l = myLeft->meval (valuation);
  if (!l)
    return NULL;
  class PlaceMarking* r = myRight->meval (valuation);
  if (!r) {
    delete l;
    return NULL;
  }

  switch (myOp) {
  case sIntersection:
    // swap so that the result is computed in the larger marking
    // (relies on intersection being commutative)
    if (l->size () < r->size ()) {
      class PlaceMarking* p = r; r = l; l = p;
    }
    *l &= *r;
    break;
  case sMinus:
    *l -= *r;
    break;
  case sUnion:
    // add the smaller marking into the larger one (union is commutative)
    if (l->size () < r->size ()) {
      class PlaceMarking* p = r; r = l; l = p;
    }
    if (!l->add (*r, 1)) {
      valuation.flag (errCard, *this);
      delete l;
      delete r;
      return NULL;
    }
    break;
  default:
    // only the multi-set operators are evaluated here; release both
    // operand markings so that builds with NDEBUG do not leak
    assert (false);
    delete l;
    delete r;
    return NULL;
  }

  delete r;
  l->setPlace (NULL);
  return l;
}

class Expression*
SetExpression::ground (const class Valuation& valuation,
		       class Transition* transition,
		       bool declare)
{
  class Expression* left = myLeft->ground (valuation, transition, declare);
  if (!left) return NULL;
  class Expression* right = myRight->ground (valuation, transition, declare);
  if (!right) { left->destroy (); return NULL; }

  assert (valuation.isOK ());

  if (left == myLeft && right == myRight) {
    left->destroy ();
    right->destroy ();
    return copy ();
  }
  else
    return static_cast<class Expression*>
      (new class SetExpression (myOp, *left, *right))->ground (valuation);
}

class Expression*
SetExpression::substitute (class Substitution& substitution)
{
  class Expression* left = myLeft->substitute (substitution);
  class Expression* right = myRight->substitute (substitution);

  if (left == myLeft && right == myRight) {
    left->destroy ();
    right->destroy ();
    return copy ();
  }
  else
    return (new class SetExpression (myOp, *left, *right))->cse ();
}

/** Determine whether either operand depends on the given variables
 * @param vars	the variable set
 * @param complement	flag passed through to the operands
 * @return	true if the left or the right operand depends on vars
 */
bool
SetExpression::depends (const class VariableSet& vars,
			bool complement) const
{
  if (myLeft->depends (vars, complement))
    return true;
  return myRight->depends (vars, complement);
}

/** Visit this expression, then the left and the right operand,
 * stopping as soon as the operation returns false
 * @param operation	the visitor function
 * @param data	context passed to the visitor
 * @return	true if every visited expression returned true
 */
bool
SetExpression::forExpressions (bool (*operation)
			       (const class Expression&,void*),
			       void* data) const
{
  if (!(*operation) (*this, data))
    return false;
  if (!myLeft->forExpressions (operation, data))
    return false;
  return myRight->forExpressions (operation, data);
}

#ifdef EXPR_COMPILE
# include "CExpression.h"

/** Generate C code for a subset or equality test on two multi-sets,
 * assigning the boolean result to lvalue
 * @param cexpr	the compilation output
 * @param indent	indentation level of the generated code
 * @param lvalue	C expression that receives the result
 * @param vars	variable set, passed through to the operands' compileMset ()
 */
void
SetExpression::compile (class CExpression& cexpr,
			unsigned indent,
			const char* lvalue,
			const class VariableSet* vars) const
{
  assert (myOp == sSubset || myOp == sEquals);
  // names of the C variables holding the operand multi-sets;
  // allocated by getVariable () and freed with delete[] below
  char* left;
  char* right;
  class StringBuffer& out = cexpr.getOut ();
  // when getVariable () returns true, emit the code that computes the operand
  if (cexpr.getVariable (*myLeft, left))
    myLeft->compileMset (cexpr, indent, 0, left, vars);
  if (cexpr.getVariable (*myRight, right))
    myRight->compileMset (cexpr, indent, 0, right, vars);
  out.indent (indent);
  out.append (lvalue);
  out.append ("=");
  if (myLeft->getKind () == eEmptySet) {
    // {} equals A: test that the right-hand multi-set is empty
    assert (myOp == sEquals);
    out.append ("!nonempty (");
    out.append (right);
    out.append (");\n");
  }
  else {
    // call the type-specific subset or equal function
    out.append (myOp == sSubset ? "subset" : "equal");
    const class Type* type = myLeft->getType ();
    if (!type) type = myRight->getType ();
    type->appendIndex (out);
    out.append (" (");
    out.append (left);
    out.append (", ");
    out.append (right);
    out.append (");\n");
  }
  delete[] left;
  delete[] right;
  compileConstraint (cexpr, indent, lvalue);
}

/** Generate C code that computes a multi-set-valued set expression
 * (union, intersection or minus) into result
 * @param cexpr	the compilation output
 * @param indent	indentation level of the generated code
 * @param resulttype	text emitted before result in the intersect/subtract
 *			call (presumably a cast), or 0 for none
 * @param result	C variable that receives the multi-set
 * @param vars	variable set, passed through to the operands' compileMset ()
 */
void
SetExpression::compileMset (class CExpression& cexpr,
			    unsigned indent,
			    const char* resulttype,
			    const char* result,
			    const class VariableSet* vars) const
{
  // the left operand is always compiled directly into result
  myLeft->compileMset (cexpr, indent, resulttype, result, vars);
  if (myOp == sUnion)
    // union: compile the right operand into the same result as well
    myRight->compileMset (cexpr, indent, resulttype, result, vars);
  else {
    assert (myOp == sIntersection || myOp == sMinus);
    // intersection/minus: compute the right operand into its own variable
    // (allocated by getVariable (), freed with delete[] below) and combine
    char* right;
    class StringBuffer& out = cexpr.getOut ();
    if (cexpr.getVariable (*myRight, right))
      myRight->compileMset (cexpr, indent, 0, right, vars);
    out.indent (indent);
    out.append (myOp == sIntersection ? "intersect" : "subtract");
    getType ()->appendIndex (out);
    out.append (" (");
    if (resulttype)
      out.append (resulttype);
    out.append (result);
    out.append (", ");
    out.append (right);
    out.append (");\n");
    delete[] right;
  }
}

#endif // EXPR_COMPILE


/** Determine whether an operand of the given kind is displayed with
 * an explicit "is TYPE" cast
 * @param kind	kind of the expression
 * @return	false for the kinds listed in the table below, true otherwise
 */
inline static bool
needsCast (enum Expression::Kind kind)
{
  /** expression kinds that are displayed without a type cast */
  static const enum Expression::Kind uncast[] = {
    Expression::eVariable,
    Expression::eStructComponent,
    Expression::eUnionComponent,
    Expression::eUnionType,
    Expression::eVectorIndex,
    Expression::eUnop,
    Expression::eBinop,
    Expression::eBooleanBinop,
    Expression::eNot,
    Expression::eRelop,
    Expression::eBufferIndex,
    Expression::eSet,
    Expression::eIfThenElse,
    Expression::eTemporalBinop,
    Expression::eTemporalUnop,
    Expression::eTypecast,
    Expression::eCardinality,
    Expression::eTransitionQualifier,
    Expression::ePlaceContents,
    Expression::eSubmarking,
    Expression::eMapping
  };
  const unsigned count = sizeof uncast / sizeof *uncast;
  for (unsigned i = 0; i < count; i++)
    if (uncast[i] == kind)
      return false;
  return true;
}

/** Map a set operator to the keyword used when displaying it
 * @param op	the operator to convert
 * @return	the keyword, or "???" for an unrecognised operator
 */
inline static const char*
getOpString (enum SetExpression::Op op)
{
  const char* keyword = "???";
  switch (op) {
  case SetExpression::sSubset:
    keyword = "subset";
    break;
  case SetExpression::sEquals:
    keyword = "equals";
    break;
  case SetExpression::sIntersection:
    keyword = "intersect";
    break;
  case SetExpression::sMinus:
    keyword = "minus";
    break;
  case SetExpression::sUnion:
    keyword = "union";
    break;
  }
  return keyword;
}

void
SetExpression::display (const class Printer& printer) const
{
  const class Type* type = myLeft->getType ();
  if (!type)
    type = myRight->getType ();
  const char* cast = type ? type->getName () : 0;
  if (cast) {
    switch (type->getKind ()) {
    case Type::tBool:
    case Type::tChar:
    case Type::tInt:
    case Type::tCard:
      cast = 0;
    default:
      break;
    }
  }

  printer.delimiter ('(')++;
  if (cast && ::needsCast (myLeft->getKind ())) {
    printer.printRaw ("is");
    printer.delimiter (' ');
    printer.print (cast);
    printer.delimiter (' ');
  }
  myLeft->display (printer);
  --printer.delimiter (')');

  printer.printRaw (::getOpString (myOp));

  printer.delimiter ('(')++;
  if (cast && ::needsCast (myRight->getKind ())) {
    printer.printRaw ("is");
    printer.delimiter (' ');
    printer.print (cast);
    printer.delimiter (' ');
  }
  myRight->display (printer);
  --printer.delimiter (')');
}
