/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2012 University of Bonn. All rights reserved.
 * Please see the COPYING file or "Copyright notice" in builder.cpp for details.
 * 
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */
/*
 * SaturationPotential.cpp
 *
 *  Created on: Oct 11, 2012
 *      Author: heber
 */
// include config.h
#ifdef HAVE_CONFIG_H
#include 
#endif
#include "CodePatterns/MemDebug.hpp"
#include "SaturationPotential.hpp"
#include "CodePatterns/Assert.hpp"
#include "CodePatterns/Log.hpp"
#include "Potentials/helpers.hpp"
/** Constructor that only sets up the cutoff and the triple function.
 *
 * The global energy offset is initialized to zero. The Morse and angle
 * components are not parametrized here and are left as they are
 * default-constructed.
 *
 * \param _saturation_cutoff cutoff passed on to \a _triplefunction when the
 *        additional distances for the angle contribution are gathered
 * \param _triplefunction function that, for a given distance r_ij and the
 *        cutoff, returns a list of distance pairs (r_ik, r_jk)
 *
 * NOTE(review): \a _triplefunction is taken by non-const reference; check in
 * the header whether the member stores a reference, in which case the
 * caller's function object must outlive this potential.
 */
SaturationPotential::SaturationPotential(
    const double _saturation_cutoff,
    boost::function< std::vector(const argument_t &, const double)> &_triplefunction) :
  energy_offset(0.),
  triplefunction(_triplefunction),
  saturation_cutoff(_saturation_cutoff)
{}
/** Constructor that fully parametrizes the potential.
 *
 * The Morse and angle components receive their parameters from the
 * respective arguments. Each component's own energy offset is forced to
 * zero, because the single offset of this potential is \a _all_energy_offset.
 *
 * \param _morse_spring_constant spring constant of the Morse component
 * \param _morse_equilibrium_distance equilibrium distance of the Morse component
 * \param _morse_dissociation_energy dissociation energy of the Morse component
 * \param _angle_spring_constant spring constant of the angle component
 * \param _angle_equilibrium_distance equilibrium value of the angle component
 * \param _all_energy_offset energy offset added once to the total energy
 * \param _saturation_cutoff cutoff passed on to \a _triplefunction
 * \param _triplefunction function returning (r_ik, r_jk) pairs for a given r_ij
 */
SaturationPotential::SaturationPotential(
    const double _morse_spring_constant,
    const double _morse_equilibrium_distance,
    const double _morse_dissociation_energy,
    const double _angle_spring_constant,
    const double _angle_equilibrium_distance,
    const double _all_energy_offset,
    const double _saturation_cutoff,
    boost::function< std::vector(const argument_t &, const double)> &_triplefunction) :
  energy_offset(_all_energy_offset),
  triplefunction(_triplefunction),
  saturation_cutoff(_saturation_cutoff)
{
  // Morse component
  {
    parameters_t values(morse.getParameterDimension());
    values[PairPotential_Morse::spring_constant] = _morse_spring_constant;
    values[PairPotential_Morse::equilibrium_distance] = _morse_equilibrium_distance;
    values[PairPotential_Morse::dissociation_energy] = _morse_dissociation_energy;
    // offset is carried by energy_offset only
    values[PairPotential_Morse::energy_offset] = 0.;
    morse.setParameters(values);
  }
  // angle component
  {
    parameters_t values(angle.getParameterDimension());
    values[PairPotential_Angle::spring_constant] = _angle_spring_constant;
    values[PairPotential_Angle::equilibrium_distance] = _angle_equilibrium_distance;
    // offset is carried by energy_offset only
    values[PairPotential_Angle::energy_offset] = 0.;
    angle.setParameters(values);
  }
}
void SaturationPotential::setParameters(const parameters_t &_params)
{
  const size_t paramsDim = _params.size();
  ASSERT( paramsDim <= getParameterDimension(),
      "SaturationPotential::setParameters() - we need not more than "
      +toString(getParameterDimension())+" parameters.");
//    LOG(1, "INFO: Setting new SaturationPotential params: " << _params);
  // offsets
  if (paramsDim > all_energy_offset)
    energy_offset = _params[all_energy_offset];
  // Morse
  {
    parameters_t morse_params(morse.getParameters());
    if (paramsDim > morse_spring_constant)
      morse_params[PairPotential_Morse::spring_constant] = _params[morse_spring_constant];
    if (paramsDim > morse_equilibrium_distance)
      morse_params[PairPotential_Morse::equilibrium_distance] = _params[morse_equilibrium_distance];
    if (paramsDim > morse_dissociation_energy)
      morse_params[PairPotential_Morse::dissociation_energy] = _params[morse_dissociation_energy];
    morse_params[PairPotential_Morse::energy_offset] = 0.;
    morse.setParameters(morse_params);
  }
  // Angle
  {
    parameters_t angle_params(angle.getParameters());
    if (paramsDim > angle_spring_constant)
      angle_params[PairPotential_Angle::spring_constant] = _params[angle_spring_constant];
    if (paramsDim > angle_equilibrium_distance)
      angle_params[PairPotential_Angle::equilibrium_distance] = _params[angle_equilibrium_distance];
    angle_params[PairPotential_Angle::energy_offset] = 0.;
    angle.setParameters(angle_params);
  }
#ifndef NDEBUG
  parameters_t check_params(getParameters());
  check_params.resize(paramsDim); // truncate to same size
  ASSERT( check_params == _params,
      "SaturationPotential::setParameters() - failed, mismatch in to be set "
      +toString(_params)+" and set "+toString(check_params)+" params.");
#endif
}
/** Returns all parameters of this potential, indexed by this class'
 * parameter enum.
 *
 * The vector holds energy_offset together with the spring constants,
 * equilibrium values and (Morse) dissociation energy taken from the Morse
 * and angle components. The components' own energy offsets are not part of
 * the result.
 */
SaturationPotential::parameters_t SaturationPotential::getParameters() const
{
  parameters_t values(getParameterDimension());
  values[all_energy_offset] = energy_offset;
  {
    const parameters_t from_morse(morse.getParameters());
    values[morse_spring_constant] = from_morse[PairPotential_Morse::spring_constant];
    values[morse_equilibrium_distance] = from_morse[PairPotential_Morse::equilibrium_distance];
    values[morse_dissociation_energy] = from_morse[PairPotential_Morse::dissociation_energy];
  }
  {
    const parameters_t from_angle(angle.getParameters());
    values[angle_spring_constant] = from_angle[PairPotential_Angle::spring_constant];
    values[angle_equilibrium_distance] = from_angle[PairPotential_Angle::equilibrium_distance];
  }
  return values;
}
/** Evaluates the energy of this potential for the given distances.
 *
 * Only distances r_ij whose first index is 0 contribute; index 0 is
 * expected to be the non-hydrogen atom. For each such distance:
 * - one Morse term in r_ij is added;
 * - for every (r_ik, r_jk) pair returned by triplefunction(r_ij,
 *   saturation_cutoff), half an angle term is added. The factor .5 is used
 *   because, as all distances are present, each triple is met as both jk
 *   and kj.
 *
 * energy_offset is added once to the sum.
 *
 * \param arguments distances to evaluate
 * \return single-element vector holding the energy
 */
SaturationPotential::results_t
SaturationPotential::operator()(
    const arguments_t &arguments
    ) const
{
  double result = 0.;
  for(arguments_t::const_iterator argiter = arguments.begin();
      argiter != arguments.end();
      ++argiter) {
    const argument_t &r_ij = *argiter;
    // first item must be the non-hydrogen
    if (r_ij.indices.first != 0)
      continue;
    arguments_t args(1, r_ij);
    // Morse contribution
    result += morse(args)[0];
    if (result != result)
      ELOG(1, "result is NAN.");
    // Angle contribution: args is filled as (r_ij, r_ik, r_jk)
    const std::vector<arguments_t> triples = triplefunction(r_ij, saturation_cutoff);
    args.resize(3, r_ij);
    for (std::vector<arguments_t>::const_iterator iter = triples.begin();
        iter != triples.end(); ++iter) {
      ASSERT( iter->size() == 2,
          "SaturationPotential::operator() - the triples result must contain exactly two distances.");
      args[1] = (*iter)[0];  // r_ik
      args[2] = (*iter)[1];  // r_jk
      result += .5*angle(args)[0];  // as we have all distances we get both jk and kj
      if (result != result)
        ELOG(1, "result is NAN.");
    }
  }
  return results_t(1, energy_offset + result);
}
/** Derivative of the energy with respect to the distances.
 *
 * Not implemented: the ASSERT fires, and an empty result is returned if
 * execution continues past it. NOTE(review): presumably ASSERT is only
 * active when NDEBUG is not defined — confirm in CodePatterns/Assert.hpp.
 *
 * \return empty derivative_components_t
 */
SaturationPotential::derivative_components_t
SaturationPotential::derivative(
    const arguments_t & /* arguments */
    ) const
{
  ASSERT( 0,
      "SaturationPotential::derivative() - not implemented.");
  return derivative_components_t();
}
/** Derivative of the energy with respect to a single parameter.
 *
 * Follows the same summation as operator(): only distances r_ij with first
 * index 0 (the non-hydrogen) contribute. The Morse parameters are
 * differentiated in r_ij alone. The angle parameters are differentiated over
 * all (r_ik, r_jk) pairs returned by triplefunction(r_ij, saturation_cutoff),
 * weighted by .5 because each triple is met as both jk and kj.
 *
 * \param arguments distances to evaluate, as for operator()
 * \param index parameter index from this class' parameter enum; must be
 *        smaller than getParameterDimension()
 * \return single-element vector holding the derivative
 */
SaturationPotential::results_t
SaturationPotential::parameter_derivative(
    const arguments_t &arguments,
    const size_t index
    ) const
{
  ASSERT( index < getParameterDimension(),
      "SaturationPotential::parameter_derivative() - index "
      +toString(index)+" exceeds the "+toString(getParameterDimension())+" parameters.");
  // energy_offset enters linearly and independently of the distances
  if (index == all_energy_offset)
    return results_t(1, 1.);

  double result = 0.;
  for(arguments_t::const_iterator argiter = arguments.begin();
      argiter != arguments.end();
      ++argiter) {
    const argument_t &r_ij = *argiter;
    // first item must be the non-hydrogen
    if (r_ij.indices.first != 0)
      continue;
    arguments_t args(1, r_ij);
    switch (index) {
      case morse_spring_constant:
        result += morse.parameter_derivative(args, PairPotential_Morse::spring_constant)[0];
        break;
      case morse_equilibrium_distance:
        result += morse.parameter_derivative(args, PairPotential_Morse::equilibrium_distance)[0];
        break;
      case morse_dissociation_energy:
        result += morse.parameter_derivative(args, PairPotential_Morse::dissociation_energy)[0];
        break;
      case angle_spring_constant:
      case angle_equilibrium_distance:
      {
        // args is filled as (r_ij, r_ik, r_jk)
        args.resize(3, r_ij);
        const std::vector<arguments_t> triples = triplefunction(r_ij, saturation_cutoff);
        for (std::vector<arguments_t>::const_iterator iter = triples.begin();
            iter != triples.end(); ++iter) {
          ASSERT( iter->size() == 2,
              "SaturationPotential::parameter_derivative() - the triples result must contain exactly two distances.");
          const argument_t &r_ik = (*iter)[0];
          ASSERT( r_ik.indices.first == r_ij.indices.first,
              "SaturationPotential::parameter_derivative() - i not same in ij, ik.");
          const argument_t &r_jk = (*iter)[1];
          ASSERT( r_jk.indices.first == r_ij.indices.second,
              "SaturationPotential::parameter_derivative() - j not same in ij, jk.");
          ASSERT( r_ik.indices.second == r_jk.indices.second,
              "SaturationPotential::parameter_derivative() - k not same in ik, jk.");
          args[1] = r_ik;
          args[2] = r_jk;
          // .5 due to we have all distances we get both jk and kj
          if (index == angle_spring_constant)
            result += .5*angle.parameter_derivative(args, PairPotential_Angle::spring_constant)[0];
          else
            result += .5*angle.parameter_derivative(args, PairPotential_Angle::equilibrium_distance)[0];
        }
        break;
      }
      default:
        // unknown index (already flagged by the ASSERT above): no contribution
        break;
    }
  }
  return results_t(1, result);
}