/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2014 Frederik Heber. All rights reserved.
 *
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * SaturationDistanceMaximizer.cpp
 *
 *  Created on: Jul 27, 2014
 *      Author: heber
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "CodePatterns/MemDebug.hpp"

#include "SaturationDistanceMaximizer.hpp"

#include <cmath>
#include <gsl/gsl_multimin.h>
#include <gsl/gsl_vector.h>

#include "CodePatterns/Log.hpp"

double
func(const gsl_vector *x, void *adata)
{
  // get the object whose functions we call
  SaturationDistanceMaximizer::Advocate *maximizer =
      static_cast<SaturationDistanceMaximizer::Advocate *>(adata);
  // set alphas
  maximizer->setAlphas(x);
  // calculate function value and return
  return maximizer->calculatePenality();
}

void
jacf(const gsl_vector *x, void *adata, gsl_vector *g)
{
  // get the object whose functions we call
  SaturationDistanceMaximizer::Advocate *maximizer =
      static_cast<SaturationDistanceMaximizer::Advocate *>(adata);
  // set alphas
  maximizer->setAlphas(x);
  // calculate function gradient and return
  std::vector<double> gradient = maximizer->calculatePenalityGradient();
  for (unsigned int i=0;i<gradient.size();++i)
    gsl_vector_set(g,i,gradient[i]);
}

void
funcjacf(const gsl_vector *x, void *adata, double *f, gsl_vector *g)
{
  // get the object whose functions we call
  SaturationDistanceMaximizer::Advocate *maximizer =
      static_cast<SaturationDistanceMaximizer::Advocate *>(adata);
  // set alphas
  maximizer->setAlphas(x);
  // calculate function value and return
  *f = maximizer->calculatePenality();
  std::vector<double> gradient = maximizer->calculatePenalityGradient();
  for (unsigned int i=0;i<gradient.size();++i)
    gsl_vector_set(g,i,gradient[i]);
}

std::vector<double> SaturationDistanceMaximizer::getAlphas() const
{
  std::vector<double> alphas;
  PositionContainers_t::iterator containeriter = PositionContainers.begin();
  for (unsigned int i=0; i<PositionContainers.size(); ++i, ++containeriter)
    alphas.push_back( (*containeriter)->alpha );
  return alphas;
}

void SaturationDistanceMaximizer::setAlphas(const gsl_vector *x)
{
  PositionContainers_t::iterator containeriter = PositionContainers.begin();
  for (unsigned int i=0; i<PositionContainers.size(); ++i, ++containeriter)
    (*containeriter)->alpha = gsl_vector_get(x,i);
}

void SaturationDistanceMaximizer::operator()()
{
  // some control constants
  const double tolerance = 1e-6;
  const unsigned int MAXITERATIONS = 100;

  const gsl_multimin_fdfminimizer_type *T;
  gsl_multimin_fdfminimizer *s;

  gsl_vector *x;
  gsl_multimin_function_fdf my_func;

  const unsigned int N = PositionContainers.size();
  my_func.n = N;
  my_func.f = &func;
  my_func.df = &jacf;
  my_func.fdf = &funcjacf;
  SaturationDistanceMaximizer::Advocate* const advocate = getAdvocate();
  my_func.params = advocate;

  // allocate argument and set to zero
  x = gsl_vector_alloc(N);
  for (unsigned int i=0;i<N;++i)
    gsl_vector_set(x, i, 0.);

  // set minimizer and allocate workspace
  T = gsl_multimin_fdfminimizer_vector_bfgs;
  s = gsl_multimin_fdfminimizer_alloc (T, N);

  // initialize minimizer
  gsl_multimin_fdfminimizer_set(s, &my_func, x, 0.1, tolerance); /* tolerance */

  size_t iter = 0;
  int status = 0;
  do {
    ++iter;
    status = gsl_multimin_fdfminimizer_iterate(s);

    if (status)
      break;

    status = gsl_multimin_test_gradient(s->gradient, tolerance);

  } while ((status = GSL_CONTINUE) && (iter < MAXITERATIONS));

  // set to solution
  setAlphas(s->x);

  // print solution
  if (DoLog(4)) {
    std::stringstream sstream;
    sstream << "DEBUG: Minimal alphas are ";
    for (unsigned int i=0;i<N;++i)
      sstream << gsl_vector_get(s->x,i) << ((i!= N-1) ? "," : "");
    LOG(4, sstream.str());
  }

  // free memory
  gsl_multimin_fdfminimizer_free(s);
  my_func.params = NULL;
  delete advocate;
  gsl_vector_free(x);
}

SaturationDistanceMaximizer::Advocate* SaturationDistanceMaximizer::getAdvocate()
{
  return new Advocate(*this);
}

SaturationDistanceMaximizer::position_bins_t
SaturationDistanceMaximizer::getAllPositionBins() const
{
  position_bins_t position_bins;
  position_bins.reserve(PositionContainers.size());
  for (PositionContainers_t::const_iterator containeriter = PositionContainers.begin();
      containeriter != PositionContainers.end(); ++containeriter)
    position_bins.push_back( (*containeriter)->getPositions() );

  return position_bins;
}

double SaturationDistanceMaximizer::calculatePenality() const
{
  double penalty = 0.;

  LOG(6, "DEBUG: Current alphas are " << getAlphas());

  // gather all positions
  position_bins_t position_bins = getAllPositionBins();

  // go through both bins (but with i<j)
  for (position_bins_t::const_iterator firstbiniter = position_bins.begin();
      firstbiniter != position_bins.end(); ++firstbiniter) {
    for (position_bins_t::const_iterator secondbiniter = firstbiniter;
        secondbiniter != position_bins.end(); ++secondbiniter) {
      if (firstbiniter == secondbiniter)
        continue;

      // then in each bin take each position
      for (SaturatedBond::positions_t::const_iterator firstpositioniter = firstbiniter->begin();
          firstpositioniter != firstbiniter->end(); ++firstpositioniter) {
        for (SaturatedBond::positions_t::const_iterator secondpositioniter = secondbiniter->begin();
            secondpositioniter != secondbiniter->end(); ++secondpositioniter) {
          // Both iters are from different bins, can never be the same.
          // We do not penalize over positions from same bin as their positions
          // are fixed.

          // We penalize by one over the squared distance
          penalty += 1./(firstpositioniter->DistanceSquared(*secondpositioniter));
        }
      }
    }
  }

  LOG(4, "DEBUG: Penalty is " << penalty);

  return penalty;
}

#ifdef HAVE_INLINE
inline
#else
static
#endif
size_t calculateHydrogenNo(
    const SaturatedBond::positions_t::const_iterator &_start,
    const SaturatedBond::positions_t::const_iterator &_current)
{
  const size_t HydrogenNo = std::distance(_start, _current);
  ASSERT( (HydrogenNo >= 0) && (HydrogenNo <= 2),
      "calculatePenalityGradient() - hydrogen no not in [0,2].");
  return HydrogenNo;
}

std::vector<double> SaturationDistanceMaximizer::calculatePenalityGradient() const
{
  // gather all positions
  const position_bins_t position_bins = getAllPositionBins();
  LOG(6, "DEBUG: Current alphas are " << getAlphas());

  std::vector<double> gradient(position_bins.size(), 0.);

  std::vector<double>::iterator biniter = gradient.begin();
  PositionContainers_t::const_iterator bonditer = PositionContainers.begin();
  position_bins_t::const_iterator firstbiniter = position_bins.begin();
  // go through each bond/gradient component/alpha
  for(; biniter != gradient.end(); ++biniter, ++bonditer, ++firstbiniter) {
    LOG(5, "DEBUG: Current bond is " << **bonditer << ", current bin is #"
        << std::distance(gradient.begin(), biniter) << ", set of positions are "
        << *firstbiniter);
    // skip bin if it belongs to a degree-1 bond (no alpha dependency here)
    if ((*bonditer)->saturated_bond.getDegree() == 1) {
      LOG(6, "DEBUG: Skipping due to degree 1.");
      continue;
    }

    // in the bin go through each position
    for (SaturatedBond::positions_t::const_iterator firstpositioniter = firstbiniter->begin();
        firstpositioniter != firstbiniter->end(); ++firstpositioniter) {
      LOG(6, "DEBUG: Current position is " << *firstpositioniter);

      // count the hydrogen we are looking at: Each is placed at a different position!
      const size_t HydrogenNo =
          calculateHydrogenNo(firstbiniter->begin(), firstpositioniter);
      const double alpha = (*bonditer)->alpha
          + (double)HydrogenNo * 2.*M_PI/(double)(*bonditer)->saturated_bond.getDegree();
      LOG(6, "DEBUG: HydrogenNo is " << HydrogenNo << ", alpha is " << alpha);

      // and go through each other bin
      for (position_bins_t::const_iterator secondbiniter = position_bins.begin();
          secondbiniter != position_bins.end(); ++secondbiniter) {
        // distance between hydrogens in same bin is not affected by the angle
//        if (firstbiniter == secondbiniter)
//          continue;

        // in the other bin go through each position
        for (SaturatedBond::positions_t::const_iterator secondpositioniter = secondbiniter->begin();
            secondpositioniter != secondbiniter->end(); ++secondpositioniter) {
          if (firstpositioniter == secondpositioniter) {
            LOG(7, "DEBUG: Skipping due to same positions.");
            continue;
          }
          LOG(7, "DEBUG: Second position is " << *secondpositioniter);

          // iters are from different bins, can never be the same
          const Vector distance = *firstpositioniter - *secondpositioniter;
          const double temp = -2./pow(distance.NormSquared(), 2);
          const Vector tempVector =
              (-sin(alpha)*(*bonditer)->vector_a)
              +(cos(alpha)*(*bonditer)->vector_b);
          const double result = temp * (distance.ScalarProduct(tempVector));
          *biniter += 2.*result; //for x_i and x_j
          LOG(7, "DEBUG: Total is " << result << ", temp is " << temp << ", tempVector is " << tempVector
              << ", and bondVector is " << distance << ": bin = " << *biniter);
        }
      }
    }
  }

  LOG(4, "DEBUG: Gradient of penalty is " << gradient);

  return gradient;
}

