/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2012 University of Bonn. All rights reserved.
 * Please see the COPYING file or "Copyright notice" in builder.cpp for details.
 * 
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */
/*
 * SaturationPotential.cpp
 *
 *  Created on: Oct 11, 2012
 *      Author: heber
 */
// include config.h
#ifdef HAVE_CONFIG_H
#include 
#endif
#include "CodePatterns/MemDebug.hpp"
#include "SaturationPotential.hpp"
#include 
#include  // for 'map_list_of()'
#include 
#include 
#include 
#include "CodePatterns/Assert.hpp"
#include "CodePatterns/Log.hpp"
#include "FunctionApproximation/Extractors.hpp"
#include "FunctionApproximation/TrainingData.hpp"
#include "Potentials/helpers.hpp"
#include "Potentials/ParticleTypeCheckers.hpp"
class Fragment;
using namespace boost::assign;
// static definitions
// Names of the six fit parameters of this potential. Presumably indexed by
// the parameter enum used in getParameters()/setParameters()
// (all_energy_offset, morse_spring_constant, morse_equilibrium_distance,
// morse_dissociation_energy, angle_spring_constant,
// angle_equilibrium_distance) -- TODO confirm against the header.
// NOTE(review): only the global energy offset carries a name here; the five
// Morse/angle entries are empty strings -- confirm this is intended.
const SaturationPotential::ParameterNames_t
SaturationPotential::ParameterNames =
      boost::assign::list_of
      ("all_energy_offset")
      ("")
      ("")
      ("")
      ("")
      ("")
    ;
// Token naming this potential type ("saturation"); presumably used to
// identify the potential when (de)serializing -- confirm with
// SerializablePotential.
const std::string SaturationPotential::potential_token("saturation");
/** Creates a saturation potential for a pair of particle types.
 *
 * The Morse part works on the given pair. The angle part works on the pair
 * extended by the saturation type (see addSaturationType()). Both parts keep
 * their own default parameters, which gives decent defaults for
 * parameter_derivative checking. Only the global energy offset is set, to
 * zero.
 *
 * \param _ParticleTypes exactly two particle types
 */
SaturationPotential::SaturationPotential(
    const ParticleTypes_t &_ParticleTypes) :
  SerializablePotential(_ParticleTypes),
  morse(_ParticleTypes),
  angle(addSaturationType(_ParticleTypes)),
  energy_offset(0.)
{
  // Note that the members above have already been built from
  // _ParticleTypes by the time this check runs.
  ASSERT( _ParticleTypes.size() == static_cast<size_t>(2),
      "SaturationPotential::SaturationPotential() - exactly two types must be given.");
  // A check that the second type is hydrogen (_ParticleTypes[1] == 1) is
  // currently disabled.
}
/** Creates a saturation potential with explicitly given parameters.
 *
 * The offsets of the Morse and angle parts are fixed to zero. The only
 * offset of the combined potential is \a _all_energy_offset.
 *
 * \param _ParticleTypes exactly two particle types
 * \param _all_energy_offset global energy offset
 * \param _morse_spring_constant spring constant of the Morse part
 * \param _morse_equilibrium_distance equilibrium distance of the Morse part
 * \param _morse_dissociation_energy dissociation energy of the Morse part
 * \param _angle_spring_constant spring constant of the angle part
 * \param _angle_equilibrium_distance equilibrium value of the angle part
 */
SaturationPotential::SaturationPotential(
    const ParticleTypes_t &_ParticleTypes,
    const double _all_energy_offset,
    const double _morse_spring_constant,
    const double _morse_equilibrium_distance,
    const double _morse_dissociation_energy,
    const double _angle_spring_constant,
    const double _angle_equilibrium_distance) :
  SerializablePotential(_ParticleTypes),
  morse(_ParticleTypes),
  angle(addSaturationType(_ParticleTypes)),
  energy_offset(_all_energy_offset)
{
  ASSERT( _ParticleTypes.size() == static_cast<size_t>(2),
      "SaturationPotential::SaturationPotential() - exactly two types must be given.");

  // Morse part: zero-initialized vector, then fill in the given values.
  {
    parameters_t morseParameters(morse.getParameterDimension());
    morseParameters[PairPotential_Morse::spring_constant] = _morse_spring_constant;
    morseParameters[PairPotential_Morse::equilibrium_distance] = _morse_equilibrium_distance;
    morseParameters[PairPotential_Morse::dissociation_energy] = _morse_dissociation_energy;
    morseParameters[PairPotential_Morse::energy_offset] = 0.;
    morse.setParameters(morseParameters);
  }

  // Angle part: likewise.
  {
    parameters_t angleParameters(angle.getParameterDimension());
    angleParameters[PairPotential_Angle::spring_constant] = _angle_spring_constant;
    angleParameters[PairPotential_Angle::equilibrium_distance] = _angle_equilibrium_distance;
    angleParameters[PairPotential_Angle::energy_offset] = 0.;
    angle.setParameters(angleParameters);
  }
}
void SaturationPotential::setParameters(const parameters_t &_params)
{
  const size_t paramsDim = _params.size();
  ASSERT( paramsDim <= getParameterDimension(),
      "SaturationPotential::setParameters() - we need not more than "
      +toString(getParameterDimension())+" parameters.");
//    LOG(1, "INFO: Setting new SaturationPotential params: " << _params);
  // offsets
  if (paramsDim > all_energy_offset)
    energy_offset = _params[all_energy_offset];
  // Morse
  {
    parameters_t morse_params(morse.getParameters());
    if (paramsDim > morse_spring_constant)
      morse_params[PairPotential_Morse::spring_constant] = _params[morse_spring_constant];
    if (paramsDim > morse_equilibrium_distance)
      morse_params[PairPotential_Morse::equilibrium_distance] = _params[morse_equilibrium_distance];
    if (paramsDim > morse_dissociation_energy)
      morse_params[PairPotential_Morse::dissociation_energy] = _params[morse_dissociation_energy];
    morse_params[PairPotential_Morse::energy_offset] = 0.;
    morse.setParameters(morse_params);
  }
  // Angle
  {
    parameters_t angle_params(angle.getParameters());
    if (paramsDim > angle_spring_constant)
      angle_params[PairPotential_Angle::spring_constant] = _params[angle_spring_constant];
    if (paramsDim > angle_equilibrium_distance)
      angle_params[PairPotential_Angle::equilibrium_distance] = _params[angle_equilibrium_distance];
    angle_params[PairPotential_Angle::energy_offset] = 0.;
    angle.setParameters(angle_params);
  }
#ifndef NDEBUG
  parameters_t check_params(getParameters());
  check_params.resize(paramsDim); // truncate to same size
  ASSERT( check_params == _params,
      "SaturationPotential::setParameters() - failed, mismatch in to be set "
      +toString(_params)+" and set "+toString(check_params)+" params.");
#endif
}
/** Gathers all parameters of the potential.
 *
 * The global energy offset is combined with the Morse and angle parameters.
 * The sub-potentials' own offsets are not exported.
 *
 * \return vector with getParameterDimension() entries
 */
SaturationPotential::parameters_t SaturationPotential::getParameters() const
{
  const parameters_t fromMorse(morse.getParameters());
  const parameters_t fromAngle(angle.getParameters());

  parameters_t collected(getParameterDimension());
  collected[all_energy_offset] = energy_offset;

  collected[morse_spring_constant] =
      fromMorse[PairPotential_Morse::spring_constant];
  collected[morse_equilibrium_distance] =
      fromMorse[PairPotential_Morse::equilibrium_distance];
  collected[morse_dissociation_energy] =
      fromMorse[PairPotential_Morse::dissociation_energy];

  collected[angle_spring_constant] =
      fromAngle[PairPotential_Angle::spring_constant];
  collected[angle_equilibrium_distance] =
      fromAngle[PairPotential_Angle::equilibrium_distance];

  return collected;
}
/** Serializes the Morse part, then a line break, then the angle part.
 *
 * NOTE(review): energy_offset is not written here, so it does not survive a
 * stream_to()/stream_from() round trip -- confirm this is intended.
 *
 * \param ost stream to write to
 */
void SaturationPotential::stream_to(std::ostream &ost) const
{
  // first sub-potential
  morse.stream_to(ost);
  // separator (std::endl also flushes)
  ost << std::endl;
  // second sub-potential
  angle.stream_to(ost);
}
/** Deserializes the Morse part and then the angle part.
 *
 * This reads the parts in the order written by stream_to(). Whitespace
 * between the two parts is skipped.
 *
 * \param ist stream to read from
 */
void SaturationPotential::stream_from(std::istream &ist)
{
  // first sub-potential
  morse.stream_from(ist);
  // skip separator
  ist >> std::ws;
  // second sub-potential
  angle.stream_from(ist);
}
std::vector
triplefunction(
    const argument_t &argument,
    const FunctionModel::arguments_t& args)
{
  const size_t firstindex = argument.indices.first;
  const size_t secondindex = argument.indices.second;
//  LOG(2, "DEBUG: first index is " << firstindex << ", second index is " << secondindex << ".");
  // place all arguments that share either index into a lookup map
  typedef std::map< size_t, FunctionModel::arguments_t::const_iterator > IndexLookup_t;
  IndexLookup_t LookuptoFirst;
  IndexLookup_t LookuptoSecond;
  for (FunctionModel::arguments_t::const_iterator iter = args.begin();
      iter != args.end();
      ++iter) {
    if (((*iter).indices.first == argument.indices.first)
        && ((*iter).indices.second == argument.indices.second))
      continue;
    if (firstindex == (*iter).indices.first) {
      LookuptoFirst.insert( std::make_pair( (*iter).indices.second, iter) );
    }
    else if (firstindex == (*iter).indices.second) {
      LookuptoFirst.insert( std::make_pair( (*iter).indices.first, iter) );
    }
    if (secondindex == (*iter).indices.first) {
      LookuptoSecond.insert( std::make_pair( (*iter).indices.second, iter) );
    }
    else if (secondindex == (*iter).indices.second) {
      LookuptoSecond.insert( std::make_pair((*iter).indices.first, iter) );
    }
  }
//  {
//    std::stringstream lookupstream;
//    for (IndexLookup_t::const_iterator iter = LookuptoFirst.begin();
//        iter != LookuptoFirst.end();
//        ++iter) {
//      lookupstream << "(" << iter->first << "," << *(iter->second) << ") ";
//    }
//    LOG(2, "DEBUG: LookupToFirst is " << lookupstream.str() << ".");
//  }
//  {
//    std::stringstream lookupstream;
//    for (IndexLookup_t::const_iterator iter = LookuptoSecond.begin();
//        iter != LookuptoSecond.end();
//        ++iter) {
//      lookupstream << "(" << iter->first << "," << *(iter->second) << ") ";
//    }
//    LOG(2, "DEBUG: LookuptoSecond is " << lookupstream.str() << ".");
//  }
  // now go through the first lookup as the second argument and pick the
  // corresponding third argument by the matching index
  std::vector results;
  for (IndexLookup_t::const_iterator iter = LookuptoFirst.begin();
      iter != LookuptoFirst.end();
      ++iter) {
    IndexLookup_t::const_iterator otheriter = LookuptoSecond.find(iter->first);
    ASSERT( otheriter != LookuptoSecond.end(),
        "triplefunction() - cannot find index "+toString(iter->first)
        +" in LookupToSecond");
    FunctionModel::arguments_t result(1, argument);
    result.reserve(3);
    result.push_back(*(iter->second));
    result.push_back(*(otheriter->second));
    results.push_back(result);
  }
  return results;
}
/** Evaluates the saturation potential for the given pair distances.
 *
 * The energy is the sum of:
 *  - the Morse energy of every distance whose type pair matches the Morse
 *    types (in either order);
 *  - for every distance matching the first two angle types (in either order),
 *    the angle energy of each triple found by triplefunction(). Each angle
 *    term is weighted by a multiplicity factor, 1/6 if all three angle types
 *    are equal and 1/2 if exactly two are. The factor compensates for
 *    triples that are found more than once when types repeat.
 * The global energy offset is added on top. A NaN sum is reported via ELOG.
 *
 * \param arguments pair distances of a fragment
 * \return single-element vector holding the energy
 */
SaturationPotential::results_t
SaturationPotential::operator()(
    const arguments_t &arguments
    ) const
{
  const ParticleTypes_t &morse_types = morse.getParticleTypes();
  const ParticleTypes_t &angle_types = angle.getParticleTypes();

  // Weight of each angle term, depending on which angle types coincide.
  const bool equal01 = (angle_types[0] == angle_types[1]);
  const bool equal12 = (angle_types[1] == angle_types[2]);
  const bool equal02 = (angle_types[0] == angle_types[2]);
  double multiplicity = 1.;
  if (equal01 && equal12)
    multiplicity = 1./6.;
  else if (equal01 || equal12 || equal02)
    multiplicity = .5;

  double energy = 0.;
  for (arguments_t::const_iterator argiter = arguments.begin();
      argiter != arguments.end();
      ++argiter) {
    const argument_t &r_ij = *argiter;
    const bool is_morse_pair =
        ((r_ij.types.first == morse_types[0]) && (r_ij.types.second == morse_types[1]))
        || ((r_ij.types.first == morse_types[1]) && (r_ij.types.second == morse_types[0]));
    const bool is_angle_pair =
        ((r_ij.types.first == angle_types[0]) && (r_ij.types.second == angle_types[1]))
        || ((r_ij.types.first == angle_types[1]) && (r_ij.types.second == angle_types[0]));

    if (is_morse_pair) {
      arguments_t single(1, r_ij);
      const double morse_energy = morse(single)[0];
      energy += morse_energy;
      if (energy != energy)
        ELOG(1, "result is NAN.");
    }

    if (is_angle_pair) {
      typedef std::vector<FunctionModel::arguments_t> tripleargs_t;
      const tripleargs_t triples = triplefunction(r_ij, arguments);
      for (tripleargs_t::const_iterator tripleiter = triples.begin();
          tripleiter != triples.end();
          ++tripleiter) {
        FunctionModel::arguments_t ordered =
            Extractors::reorderArgumentsByParticleTypes(*tripleiter, angle_types);
        const double angle_energy = multiplicity*angle(ordered)[0];
        energy += angle_energy;
        if (energy != energy)
          ELOG(1, "result is NAN.");
      }
    }
  }
  return results_t(1, energy_offset + energy);
}
/** Derivative of the potential with respect to its arguments.
 *
 * This is not implemented. It asserts and, when assertions are disabled,
 * returns an empty set of components.
 *
 * \param arguments pair distances (unused)
 * \return empty derivative_components_t
 */
SaturationPotential::derivative_components_t
SaturationPotential::derivative(
    const arguments_t & /* arguments */
    ) const
{
  // The message names this function (it previously said operator()).
  ASSERT( 0,
      "SaturationPotential::derivative() - not implemented.");
  return derivative_components_t();
}
/** Derivative of the potential with respect to one of its parameters.
 *
 * - all_energy_offset: the offset enters linearly, so the derivative is 1.
 * - morse_*: the sum of the Morse part's derivative over every distance whose
 *   type pair matches the Morse types (in either order).
 * - angle_*: the sum, over every distance matching the first two angle types
 *   and every triple from triplefunction(), of the angle part's derivative.
 *   Each term is weighted by the same multiplicity factor as in operator().
 *   A NaN sum is reported via ELOG.
 *
 * \param arguments pair distances of a fragment
 * \param index parameter index (all_energy_offset, morse_*, angle_*)
 * \return single-element vector holding the derivative
 */
SaturationPotential::results_t
SaturationPotential::parameter_derivative(
    const arguments_t &arguments,
    const size_t index
    ) const
{
  double result = 0.;
  switch (index) {
  case all_energy_offset:
    result = 1.;
    break;

  case morse_spring_constant:
  case morse_equilibrium_distance:
  case morse_dissociation_energy:
    {
      // Translate our parameter index into the Morse part's own index.
      size_t morse_index = 0;
      switch (index) {
        case morse_spring_constant:
          morse_index = PairPotential_Morse::spring_constant;
          break;
        case morse_equilibrium_distance:
          morse_index = PairPotential_Morse::equilibrium_distance;
          break;
        case morse_dissociation_energy:
          morse_index = PairPotential_Morse::dissociation_energy;
          break;
        default:
          ASSERT(0, "SaturationPotential::parameter_derivative() - We cannot get here.");
          break;
      }
      const ParticleTypes_t &morse_types = morse.getParticleTypes();
      for (arguments_t::const_iterator argiter = arguments.begin();
          argiter != arguments.end();
          ++argiter) {
        const argument_t &r_ij = *argiter;
        const bool is_morse_pair =
            ((r_ij.types.first == morse_types[0]) && (r_ij.types.second == morse_types[1]))
            || ((r_ij.types.first == morse_types[1]) && (r_ij.types.second == morse_types[0]));
        if (!is_morse_pair)
          continue;
        arguments_t single(1, r_ij);
        const double contribution = morse.parameter_derivative(single, morse_index)[0];
        result += contribution;
      }
    }
    break;

  case angle_spring_constant:
  case angle_equilibrium_distance:
    {
      // Translate our parameter index into the angle part's own index.
      const size_t angle_index = (index == angle_spring_constant)
          ? PairPotential_Angle::spring_constant
          : PairPotential_Angle::equilibrium_distance;
      const ParticleTypes_t &angle_types = angle.getParticleTypes();
      // Same weighting as in operator().
      const bool equal01 = (angle_types[0] == angle_types[1]);
      const bool equal12 = (angle_types[1] == angle_types[2]);
      const bool equal02 = (angle_types[0] == angle_types[2]);
      double multiplicity = 1.;
      if (equal01 && equal12)
        multiplicity = 1./6.;
      else if (equal01 || equal12 || equal02)
        multiplicity = .5;
      for (arguments_t::const_iterator argiter = arguments.begin();
          argiter != arguments.end();
          ++argiter) {
        const argument_t &r_ij = *argiter;
        const bool is_angle_pair =
            ((r_ij.types.first == angle_types[0]) && (r_ij.types.second == angle_types[1]))
            || ((r_ij.types.first == angle_types[1]) && (r_ij.types.second == angle_types[0]));
        if (!is_angle_pair)
          continue;
        typedef std::vector<FunctionModel::arguments_t> tripleargs_t;
        const tripleargs_t triples = triplefunction(r_ij, arguments);
        for (tripleargs_t::const_iterator tripleiter = triples.begin();
            tripleiter != triples.end();
            ++tripleiter) {
          FunctionModel::arguments_t ordered =
              Extractors::reorderArgumentsByParticleTypes(*tripleiter, angle_types);
          const double contribution =
              multiplicity*angle.parameter_derivative(ordered, angle_index)[0];
          result += contribution;
          if (result != result)
            ELOG(1, "result is NAN.");
        }
      }
    }
    break;

  default:
    ASSERT( 0, "SaturationPotential::parameter_derivative() - impossible to get here.");
    break;
  }
  return results_t(1, result);
}
/** Extends a pair of particle types to a triple by repeating its last type.
 *
 * For example, (A,B) becomes (A,B,B). Prepending the second type instead,
 * giving (B,A,B), would be the alternative.
 *
 * \param _ParticleTypes exactly two particle types
 * \return the three particle types
 */
const SaturationPotential::ParticleTypes_t
SaturationPotential::symmetrizeTypes(const ParticleTypes_t &_ParticleTypes)
{
  ASSERT( _ParticleTypes.size() == (size_t)2,
      "SaturationPotential::symmetrizeTypes() - require initial _ParticleTypes with two elements.");
  ParticleTypes_t extended(_ParticleTypes.begin(), _ParticleTypes.end());
  extended.push_back( _ParticleTypes.back() );
  ASSERT( extended.size() == (size_t)3,
      "SaturationPotential::symmetrizeTypes() - failed to generate three types for angle.");
  return extended;
}
/** Appends the saturation particle type to the given types.
 *
 * The appended type is ParticleType_t(1), presumably hydrogen (atomic
 * number 1) -- TODO confirm.
 *
 * \param _ParticleTypes types to extend
 * \return a copy of \a _ParticleTypes with the saturation type appended
 */
const SaturationPotential::ParticleTypes_t
SaturationPotential::addSaturationType(const ParticleTypes_t &_ParticleTypes)
{
  const ParticleType_t saturation_type(1);
  ParticleTypes_t extended(_ParticleTypes);
  extended.push_back(saturation_type);
  return extended;
}
/** Returns the extractor that turns a fragment into this potential's arguments.
 *
 * The returned function binds Extractors::combineArguments over two calls of
 * Extractors::gatherAllDistancesFromFragment. Both calls take the fragment's
 * positions and charges (via placeholder _1) and the extractor's second
 * argument (placeholder _2). The first call uses the angle part's particle
 * types as the charges to gather. The second call uses the Morse part's types.
 *
 * The commented-out code below keeps earlier or more general variants
 * (equal-type shortcut, explicit pair lists, final reordering by angle types).
 * They are kept for later generalization.
 *
 * \return extractor combining the angle-type and Morse-type distances
 */
FunctionModel::extractor_t
SaturationPotential::getFragmentSpecificExtractor() const
{
//  Fragment::charges_t charges;
//  charges.resize(getParticleTypes().size());
//  std::transform(getParticleTypes().begin(), getParticleTypes().end(),
//      charges.begin(), boost::lambda::_1);
  FunctionModel::extractor_t returnfunction;
//  if (charges[0] == charges[1]) {
//    // In case both types are equal there is only a single pair of possible
//    // type combinations.
//     returnfunction =
//        boost::bind(&Extractors::gatherAllDistancesFromFragment,
//            boost::bind(&Fragment::getPositions, _1),
//            boost::bind(&Fragment::getCharges, _1),
//            charges, // is only temporarily created, hence copy
//            _2);
//  } else {
    // we have to chain here a rather complex "tree" of functions
    // as we only have a couple of ParticleTypes but need to get
    // all possible three pairs of the set of the two types.
    // Finally, we also need to arrange them in correct order
    // (for PairPotentiale_Angle).
//    charges_t firstpair(2, boost::cref(charges[0]));
    // only that saturation potential never has its middle element twice!
    // hence, we skip the firstpair but keep the code for later generalization
//    Fragment::charges_t secondpair(2, boost::cref(charges[1]));
//    const Fragment::charges_t &thirdpair = charges;
    // Element-wise copy (identity transform via boost::lambda::_1) of the
    // angle part's particle types into the charges container type.
    Fragment::charges_t charges_angle;
    {
      charges_angle.resize(angle.getParticleTypes().size());
      std::transform(angle.getParticleTypes().begin(), angle.getParticleTypes().end(),
          charges_angle.begin(), boost::lambda::_1);
    }
    // Same for the Morse part's particle types.
    Fragment::charges_t charges_morse;
    {
      charges_morse.resize(morse.getParticleTypes().size());
      std::transform(morse.getParticleTypes().begin(), morse.getParticleTypes().end(),
          charges_morse.begin(), boost::lambda::_1);
    }
    // charges_angle / charges_morse are locals, so boost::bind must store
    // copies of them (no boost::cref), or the extractor would dangle.
    returnfunction =
//        boost::bind(&Extractors::reorderArgumentsByParticleTypes,
          boost::bind(&Extractors::combineArguments,
//            boost::bind(&Extractors::combineArguments,
              boost::bind(&Extractors::gatherAllDistancesFromFragment,
                  boost::bind(&Fragment::getPositions, _1),
                  boost::bind(&Fragment::getCharges, _1),
                  charges_angle,  // no crefs here as are temporaries!
                  _2),
              boost::bind(&Extractors::gatherAllDistancesFromFragment,
                  boost::bind(&Fragment::getPositions, _1),
                  boost::bind(&Fragment::getCharges, _1),
                  charges_morse,  // no crefs here as are temporaries!
                  _2)
//            )
//            boost::bind(&Extractors::gatherAllDistancesFromFragment,
//                boost::bind(&Fragment::getPositions, _1),
//                boost::bind(&Fragment::getCharges, _1),
//                boost::cref(thirdpair), // only the last one is no temporary
//                _2)
//          ),
//          boost::bind(&PairPotential_Angle::getParticleTypes, boost::cref(angle))
        );
//  }
  return returnfunction;
}
/** Initializes all parameters from the given training data.
 *
 * The global energy offset starts at the first component of the training
 * output average. The Morse and angle parts then initialize themselves via
 * their own setParametersToRandomInitialValues(), in that order.
 *
 * \param data training data to derive the initial values from
 */
void
SaturationPotential::setParametersToRandomInitialValues(
    const TrainingData &data)
{
  const double average_output = data.getTrainingOutputAverage()[0];
  energy_offset = average_output;
  morse.setParametersToRandomInitialValues(data);
  angle.setParametersToRandomInitialValues(data);
}