/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2010-2012 University of Bonn. All rights reserved.
 * Copyright (C)  2013 Frederik Heber. All rights reserved.
 * 
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * AtomDescriptorUnitTest.cpp
 *
 *  Created on: Feb 9, 2010
 *      Author: crueger
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "AtomDescriptorUnitTest.hpp"

#include <cppunit/CompilerOutputter.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <cppunit/ui/text/TestRunner.h>
#include <iostream>

#include <Descriptors/AtomDescriptor.hpp>
#include <Descriptors/AtomIdDescriptor.hpp>
#include <Descriptors/AtomOfMoleculeDescriptor.hpp>
#include <Descriptors/AtomOrderDescriptor.hpp>
#include <Descriptors/AtomsWithinDistanceOfDescriptor.hpp>

#include "World.hpp"
#include "Atom/atom.hpp"
#include "molecule.hpp"
#include "LinearAlgebra/Vector.hpp"

#ifdef HAVE_TESTRUNNER
#include "UnitTestMain.hpp"
#endif /*HAVE_TESTRUNNER*/

/********************************************** Test classes **************************************/
// Registers the fixture into the 'registry'
CPPUNIT_TEST_SUITE_REGISTRATION( AtomDescriptorTest );

// set up and tear down
void AtomDescriptorTest::setUp()
{
  World::getInstance();
  for(int i=0;i<ATOM_COUNT;++i){
    atoms[i]= World::getInstance().createAtom();
    atomIds[i]= atoms[i]->getId();
  }
}

void AtomDescriptorTest::tearDown()
{
  World::purgeInstance();
}

// some helper functions
static bool hasAllAtoms(std::vector<atom*> atoms,atomId_t ids[ATOM_COUNT], std::set<atomId_t> excluded = std::set<atomId_t>())
{
  for(int i=0;i<ATOM_COUNT;++i){
    atomId_t id = ids[i];
    if(!excluded.count(id)){
      std::vector<atom*>::iterator iter;
      bool res=false;
      for(iter=atoms.begin();iter!=atoms.end();++iter){
        res |= (*iter)->getId() == id;
      }
      if(!res) {
        cout << "Atom " << id << " missing in returned list" << endl;
        return false;
      }
    }
  }
  return true;
}

static bool hasNoDuplicateAtoms(std::vector<atom*> atoms)
{
  std::set<atomId_t> found;
  std::vector<atom*>::iterator iter;
  for(iter=atoms.begin();iter!=atoms.end();++iter){
    int id = (*iter)->getId();
    if(found.count(id))
      return false;
    found.insert(id);
  }
  return true;
}


void AtomDescriptorTest::AtomBaseSetsTest()
{
  std::vector<atom*> allAtoms = World::getInstance().getAllAtoms(AllAtoms());
  CPPUNIT_ASSERT_EQUAL( true , hasAllAtoms(allAtoms,atomIds));
  CPPUNIT_ASSERT_EQUAL( true , hasNoDuplicateAtoms(allAtoms));

  std::vector<atom*> noAtoms = World::getInstance().getAllAtoms(NoAtoms());
  CPPUNIT_ASSERT_EQUAL( true , noAtoms.empty());
}

void AtomDescriptorTest::AtomIdTest()
{
  // test Atoms from boundaries and middle of the set
  atom* testAtom;
  testAtom = World::getInstance().getAtom(AtomById(atomIds[0]));
  CPPUNIT_ASSERT(testAtom);
  CPPUNIT_ASSERT_EQUAL( atomIds[0], testAtom->getId());
  testAtom = World::getInstance().getAtom(AtomById(atomIds[ATOM_COUNT/2]));
  CPPUNIT_ASSERT(testAtom);
  CPPUNIT_ASSERT_EQUAL( atomIds[ATOM_COUNT/2], testAtom->getId());
  testAtom = World::getInstance().getAtom(AtomById(atomIds[ATOM_COUNT-1]));
  CPPUNIT_ASSERT(testAtom);
  CPPUNIT_ASSERT_EQUAL( atomIds[ATOM_COUNT-1], testAtom->getId());

  // find some ID that has not been created
  atomId_t outsideId=0;
  bool res = false;
  for(outsideId=0;!res;++outsideId) {
    res = true;
    for(int i = 0; i < ATOM_COUNT; ++i){
      res &= atomIds[i]!=outsideId;
    }
  }
  // test from outside of set
  testAtom = World::getInstance().getAtom(AtomById(outsideId));
  CPPUNIT_ASSERT(!testAtom);
}

void AtomDescriptorTest::AtomOfMoleculeTest()
{
  // test Atoms from boundaries and middle of the set
  atom* testAtom;
  testAtom = World::getInstance().getAtom(AtomById(atomIds[0]));
  CPPUNIT_ASSERT(testAtom);
  CPPUNIT_ASSERT_EQUAL( atomIds[0], testAtom->getId());

  // create some molecule and associate atom to it
  testAtom->setType(1);
  molecule * newmol = World::getInstance().createMolecule();
  newmol->AddAtom(testAtom);
  CPPUNIT_ASSERT_EQUAL(newmol->getId(), testAtom->getMolecule()->getId());

  // get atom by descriptor
  World::AtomComposite atoms = World::getInstance().getAllAtoms(AtomOfMolecule(newmol->getId()));
  CPPUNIT_ASSERT_EQUAL( (size_t)1, atoms.size() );
  CPPUNIT_ASSERT_EQUAL( (*atoms.begin())->getId(), testAtom->getId() );

  // remove molecule again
  World::getInstance().destroyMolecule(newmol);
}

void AtomDescriptorTest::AtomOrderTest()
{
  atom* testAtom;

  // test in normal order: 1, 2, ...
  for(int i=1;i<=ATOM_COUNT;++i){
    testAtom = World::getInstance().getAtom(AtomByOrder(i));
    CPPUNIT_ASSERT_EQUAL( atomIds[i-1], testAtom->getId());
  }

  // test in reverse order: -1, -2, ...
  for(int i=1; i<= ATOM_COUNT;++i){
    testAtom = World::getInstance().getAtom(AtomByOrder(-i));
    CPPUNIT_ASSERT_EQUAL( atomIds[(int)ATOM_COUNT-i], testAtom->getId());
  }

  // test from outside of set
  testAtom = World::getInstance().getAtom(AtomByOrder(ATOM_COUNT+1));
  CPPUNIT_ASSERT(!testAtom);
  testAtom = World::getInstance().getAtom(AtomByOrder(-ATOM_COUNT-1));
  CPPUNIT_ASSERT(!testAtom);
}


std::set<atomId_t> getDistanceList(const double distance, const Vector &position, atom **list)
{
  const double distanceSquared = distance*distance;
  std::set<atomId_t> reflist;
  for (size_t i=0; i<ATOM_COUNT;++i)
    if (list[i]->getPosition().DistanceSquared(position) < distanceSquared)
      reflist.insert ( list[i]->getId() );
  return reflist;
}


std::set<atomId_t> getIdList(const World::AtomComposite &list)
{
  std::set<atomId_t> testlist;
  for (World::AtomComposite::const_iterator iter = list.begin();
      iter != list.end(); ++iter)
    testlist.insert( (*iter)->getId() );
  return testlist;
}

//void AtomDescriptorTest::AtomsShapeTest()
//{
//  // align atoms along an axis
//  for(int i=0;i<ATOM_COUNT;++i) {
//    atoms[i]->setPosition(Vector((double)i, 0., 0.));
//    //std::cout << "atoms[" << i << "]: " << atoms[i]->getId() << " at " << atoms[i]->getPosition() << std::endl;
//  }
//
//  // get atom by descriptor ...
//  // ... from origin up to 2.5
//  {
//    const double distance = 1.5;
//    Vector position(0.,0.,0.);
//    Shape s = Sphere(position, distance);
//    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsByShape(s));
//    CPPUNIT_ASSERT_EQUAL( (size_t)2, atomlist.size() );
//    std::set<atomId_t> reflist = getDistanceList(distance, position, atoms);
//    std::set<atomId_t> testlist = getIdList(atomlist);
//    CPPUNIT_ASSERT_EQUAL( reflist, testlist );
//  }
//  // ... from (4,0,0) up to 2.9 (i.e. more shells or different view)
//  {
//    const double distance = 2.9;
//    Vector position(4.,0.,0.);
//    Shape s = Sphere(position, distance);
//    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsByShape(s));
//    CPPUNIT_ASSERT_EQUAL( (size_t)5, atomlist.size() );
//    std::set<atomId_t> reflist = getDistanceList(distance, position, atoms);
//    std::set<atomId_t> testlist = getIdList(atomlist);
//    CPPUNIT_ASSERT_EQUAL( reflist, testlist );
//  }
//  // ... from (10,0,0) up to 1.5
//  {
//    const double distance = 1.5;
//    Vector *position = new Vector(10.,0.,0.);
//    Shape s = Sphere(position, distance);
//    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsByShape(s));
//    CPPUNIT_ASSERT_EQUAL( (size_t)1, atomlist.size() );
//    std::set<atomId_t> reflist = getDistanceList(distance, *position, atoms);
//    std::set<atomId_t> testlist = getIdList(atomlist);
//    CPPUNIT_ASSERT_EQUAL( reflist, testlist );
//    delete position;
//  }
//}

void AtomDescriptorTest::AtomsWithinDistanceOfTest()
{
  // align atoms along an axis
  for(int i=0;i<ATOM_COUNT;++i) {
    atoms[i]->setPosition(Vector((double)i, 0., 0.));
    //std::cout << "atoms[" << i << "]: " << atoms[i]->getId() << " at " << atoms[i]->getPosition() << std::endl;
  }

  // get atom by descriptor ...
  // ... from origin up to 2.5
  {
    const double distance = 1.5;
    Vector position(0.,0.,0.);
    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsWithinDistanceOf(distance, position));
    CPPUNIT_ASSERT_EQUAL( (size_t)2, atomlist.size() );
    std::set<atomId_t> reflist = getDistanceList(distance, position, atoms);
    std::set<atomId_t> testlist = getIdList(atomlist);
    CPPUNIT_ASSERT( reflist == testlist );
  }
  // ... from (4,0,0) up to 2.9 (i.e. more shells or different view)
  {
    const double distance = 2.9;
    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsWithinDistanceOf(distance, Vector(4.,0.,0.)));
    CPPUNIT_ASSERT_EQUAL( (size_t)5, atomlist.size() );
    std::set<atomId_t> reflist = getDistanceList(distance, Vector(4.,0.,0.), atoms);
    std::set<atomId_t> testlist = getIdList(atomlist);
    CPPUNIT_ASSERT( reflist == testlist );
  }
  // ... from (10,0,0) up to 1.5
  {
    const double distance = 1.5;
    Vector *position = new Vector(10.,0.,0.);
    World::AtomComposite atomlist = World::getInstance().getAllAtoms(AtomsWithinDistanceOf(distance, *position));
    CPPUNIT_ASSERT_EQUAL( (size_t)1, atomlist.size() );
    std::set<atomId_t> reflist = getDistanceList(distance, *position, atoms);
    std::set<atomId_t> testlist = getIdList(atomlist);
    CPPUNIT_ASSERT( reflist == testlist );
    delete position;
  }
}

void AtomDescriptorTest::AtomCalcTest()
{
  // test some elementary set operations
  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(AllAtoms()||NoAtoms());
    CPPUNIT_ASSERT_EQUAL( true , hasAllAtoms(testAtoms,atomIds));
    CPPUNIT_ASSERT_EQUAL( true , hasNoDuplicateAtoms(testAtoms));
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(NoAtoms()||AllAtoms());
    CPPUNIT_ASSERT_EQUAL( true , hasAllAtoms(testAtoms,atomIds));
    CPPUNIT_ASSERT_EQUAL( true , hasNoDuplicateAtoms(testAtoms));
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(NoAtoms()&&AllAtoms());
    CPPUNIT_ASSERT_EQUAL( true , testAtoms.empty());
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(AllAtoms()&&NoAtoms());
    CPPUNIT_ASSERT_EQUAL( true , testAtoms.empty());
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(!AllAtoms());
    CPPUNIT_ASSERT_EQUAL( true , testAtoms.empty());
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(!NoAtoms());
    CPPUNIT_ASSERT_EQUAL( true , hasAllAtoms(testAtoms,atomIds));
    CPPUNIT_ASSERT_EQUAL( true , hasNoDuplicateAtoms(testAtoms));
  }
  // exclude and include some atoms
  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(AllAtoms()&&(!AtomById(atomIds[ATOM_COUNT/2])));
    std::set<atomId_t> excluded;
    excluded.insert(atomIds[ATOM_COUNT/2]);
    CPPUNIT_ASSERT_EQUAL( true , hasAllAtoms(testAtoms,atomIds,excluded));
    CPPUNIT_ASSERT_EQUAL( true , hasNoDuplicateAtoms(testAtoms));
    CPPUNIT_ASSERT_EQUAL( (size_t)(ATOM_COUNT-1), testAtoms.size());
  }

  {
    std::vector<atom*> testAtoms = World::getInstance().getAllAtoms(NoAtoms()||(AtomById(atomIds[ATOM_COUNT/2])));
    CPPUNIT_ASSERT_EQUAL( (size_t)1, testAtoms.size());
    CPPUNIT_ASSERT_EQUAL( atomIds[ATOM_COUNT/2], testAtoms[0]->getId());
  }
}
