Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/Actions/PotentialAction/FitPotentialAction.cpp

    r707a2b r0ea063  
    5555#include "Fragmentation/Homology/HomologyGraph.hpp"
    5656#include "Fragmentation/Summation/SetValues/Fragment.hpp"
    57 #include "FunctionApproximation/Extractors.hpp"
    58 #include "FunctionApproximation/FunctionApproximation.hpp"
    59 #include "FunctionApproximation/FunctionModel.hpp"
    60 #include "FunctionApproximation/TrainingData.hpp"
    61 #include "FunctionApproximation/writeDistanceEnergyTable.hpp"
    62 #include "Potentials/CompoundPotential.hpp"
    63 #include "Potentials/Exceptions.hpp"
    64 #include "Potentials/PotentialDeserializer.hpp"
     57#include "Potentials/EmpiricalPotential.hpp"
    6558#include "Potentials/PotentialFactory.hpp"
    6659#include "Potentials/PotentialRegistry.hpp"
    6760#include "Potentials/PotentialSerializer.hpp"
     61#include "Potentials/PotentialTrainer.hpp"
    6862#include "Potentials/SerializablePotential.hpp"
     63#include "World.hpp"
    6964
    7065using namespace MoleCuilder;
     
    7570/** =========== define the function ====================== */
    7671
    77 HomologyGraph getFirstGraphwithSpecifiedElements(
    78     const HomologyContainer &homologies,
    79     const SerializablePotential::ParticleTypes_t &types)
    80 {
    81   ASSERT( !types.empty(),
    82       "getFirstGraphwithSpecifiedElements() - charges is empty?");
    83   // create charges
    84   Fragment::charges_t charges;
    85   charges.resize(types.size());
    86   std::transform(types.begin(), types.end(),
    87       charges.begin(), boost::lambda::_1);
    88   // convert into count map
    89   Extractors::elementcounts_t counts_per_charge =
    90       Extractors::_detail::getElementCounts(charges);
    91   ASSERT( !counts_per_charge.empty(),
    92       "getFirstGraphwithSpecifiedElements() - charge counts are empty?");
    93   LOG(2, "DEBUG: counts_per_charge is " << counts_per_charge << ".");
    94   // we want to check each (unique) key only once
    95   for (HomologyContainer::const_key_iterator iter = homologies.key_begin();
    96       iter != homologies.key_end(); iter = homologies.getNextKey(iter)) {
    97     // check if every element has the right number of counts
    98     Extractors::elementcounts_t::const_iterator countiter = counts_per_charge.begin();
    99     for (; countiter != counts_per_charge.end(); ++countiter)
    100       if (!(*iter).hasTimesAtomicNumber(
    101           static_cast<size_t>(countiter->first),
    102           static_cast<size_t>(countiter->second))
    103           )
    104         break;
    105     if( countiter == counts_per_charge.end())
    106       return *iter;
    107   }
    108   return HomologyGraph();
    109 }
    110 
    111 SerializablePotential::ParticleTypes_t getNumbersFromElements(
    112     const std::vector<const element *> &fragment)
    113 {
    114   SerializablePotential::ParticleTypes_t fragmentnumbers;
    115   std::transform(fragment.begin(), fragment.end(), std::back_inserter(fragmentnumbers),
    116       boost::bind(&element::getAtomicNumber, _1));
    117   return fragmentnumbers;
    118 }
    119 
    120 
    12172ActionState::ptr PotentialFitPotentialAction::performCall() {
    12273  // fragment specifies the homology fragment to use
    12374  SerializablePotential::ParticleTypes_t fragmentnumbers =
    124       getNumbersFromElements(params.fragment.get());
     75      PotentialTrainer::getNumbersFromElements(params.fragment.get());
    12576
    12677  // either charges and a potential is specified or a file
    127   if (boost::filesystem::exists(params.potential_file.get())) {
    128     std::ifstream returnstream(params.potential_file.get().string().c_str());
    129     if (returnstream.good()) {
    130       try {
    131         PotentialDeserializer deserialize(returnstream);
    132         deserialize();
    133       } catch (SerializablePotentialMissingValueException &e) {
    134         if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e))
    135           STATUS("Missing value when parsing information for potential "+*key+".");
    136         else
    137           STATUS("Missing value parsing information for potential with unknown key.");
    138         return Action::failure;
    139       } catch (SerializablePotentialIllegalKeyException &e) {
    140         if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e))
    141           STATUS("Illegal key parsing information for potential "+*key+".");
    142         else
    143           STATUS("Illegal key parsing information for potential with unknown key.");
    144         return Action::failure;
    145       }
    146     } else {
    147       STATUS("Failed to parse from "+params.potential_file.get().string()+".");
    148       return Action::failure;
     78  if (params.charges.get().empty()) {
     79    STATUS("No charges given!");
     80    return Action::failure;
     81  } else {
     82    // charges specify the potential type
     83    SerializablePotential::ParticleTypes_t chargenumbers =
     84        PotentialTrainer::getNumbersFromElements(params.charges.get());
     85
     86    LOG(0, "STATUS: I'm training now a " << params.potentialtype.get()
     87        << " potential on charges " << chargenumbers << " on data from World's homologies.");
     88
     89    // register desired potential and an additional constant one
     90    {
     91      EmpiricalPotential *potential =
     92          PotentialFactory::getInstance().createInstance(
     93              params.potentialtype.get(),
     94              chargenumbers);
     95      // check whether such a potential already exists
     96      const std::string potential_name = potential->getName();
     97      if (PotentialRegistry::getInstance().isPresentByName(potential_name)) {
     98        delete potential;
     99        potential = PotentialRegistry::getInstance().getByName(potential_name);
     100      } else
     101        PotentialRegistry::getInstance().registerInstance(potential);
    149102    }
    150     returnstream.close();
    151 
    152     LOG(0, "STATUS: I'm training now a set of potentials parsed from "
    153         << params.potential_file.get().string() << " on a fragment "
    154         << fragmentnumbers << " on data from World's homologies.");
    155 
    156   } else {
    157     if (params.charges.get().empty()) {
    158       STATUS("Neither charges nor potential file given!");
    159       return Action::failure;
    160     } else {
    161       // charges specify the potential type
    162       SerializablePotential::ParticleTypes_t chargenumbers =
    163           getNumbersFromElements(params.charges.get());
    164 
    165       LOG(0, "STATUS: I'm training now a " << params.potentialtype.get()
    166           << " potential on charges " << chargenumbers << " on data from World's homologies.");
    167 
    168       // register desired potential and an additional constant one
    169       {
    170         EmpiricalPotential *potential =
    171             PotentialFactory::getInstance().createInstance(
    172                 params.potentialtype.get(),
    173                 chargenumbers);
    174         // check whether such a potential already exists
    175         const std::string potential_name = potential->getName();
    176         if (PotentialRegistry::getInstance().isPresentByName(potential_name)) {
    177           delete potential;
    178           potential = PotentialRegistry::getInstance().getByName(potential_name);
    179         } else
    180           PotentialRegistry::getInstance().registerInstance(potential);
    181       }
    182       {
    183         EmpiricalPotential *constant =
    184             PotentialFactory::getInstance().createInstance(
    185                 std::string("constant"),
    186                 SerializablePotential::ParticleTypes_t());
    187         // check whether such a potential already exists
    188         const std::string constant_name = constant->getName();
    189         if (PotentialRegistry::getInstance().isPresentByName(constant_name)) {
    190           delete constant;
    191           constant = PotentialRegistry::getInstance().getByName(constant_name);
    192         } else
    193           PotentialRegistry::getInstance().registerInstance(constant);
    194       }
     103    {
     104      EmpiricalPotential *constant =
     105          PotentialFactory::getInstance().createInstance(
     106              std::string("constant"),
     107              SerializablePotential::ParticleTypes_t());
     108      // check whether such a potential already exists
     109      const std::string constant_name = constant->getName();
     110      if (PotentialRegistry::getInstance().isPresentByName(constant_name)) {
     111        delete constant;
     112        constant = PotentialRegistry::getInstance().getByName(constant_name);
     113      } else
     114        PotentialRegistry::getInstance().registerInstance(constant);
    195115    }
    196116  }
    197117
    198118  // parse homologies into container
    199   HomologyContainer &homologies = World::getInstance().getHomologies();
     119  const HomologyContainer &homologies = World::getInstance().getHomologies();
    200120
    201121  // first we try to look into the HomologyContainer
     
    211131
    212132  // then we ought to pick the right HomologyGraph ...
    213   const HomologyGraph graph = getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers);
     133  const HomologyGraph graph =
     134      PotentialTrainer::getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers);
    214135  if (graph != HomologyGraph()) {
    215136    LOG(1, "First representative graph containing fragment "
     
    220141  }
    221142
    222   // fit potential
    223   FunctionModel *model = new CompoundPotential(graph);
    224   ASSERT( model != NULL,
    225       "PotentialFitPotentialAction::performCall() - model is NULL.");
    226 
    227   /******************** TRAINING ********************/
    228   // fit potential
    229   FunctionModel::parameters_t bestparams(model->getParameterDimension(), 0.);
    230   {
    231     // Afterwards we go through all of this type and gather the distance and the energy value
    232     TrainingData data(model->getSpecificFilter());
    233     data(homologies.getHomologousGraphs(graph));
    234 
    235     // print distances and energies if desired for debugging
    236     if (!data.getTrainingInputs().empty()) {
    237       // print which distance is which
    238       size_t counter=1;
    239       if (DoLog(3)) {
    240         const FunctionModel::arguments_t &inputs = data.getAllArguments()[0];
    241         for (FunctionModel::arguments_t::const_iterator iter = inputs.begin();
    242             iter != inputs.end(); ++iter) {
    243           const argument_t &arg = *iter;
    244           LOG(3, "DEBUG: distance " << counter++ << " is between (#"
    245               << arg.indices.first << "c" << arg.types.first << ","
    246               << arg.indices.second << "c" << arg.types.second << ").");
    247         }
    248       }
    249 
    250       // print table
    251       if (params.training_file.get().string().empty()) {
    252         LOG(3, "DEBUG: I gathered the following training data:\n" <<
    253             _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable()));
    254       } else {
    255         std::ofstream trainingstream(params.training_file.get().string().c_str());
    256         if (trainingstream.good()) {
    257           LOG(3, "DEBUG: Writing training data to file " <<
    258               params.training_file.get().string() << ".");
    259           trainingstream << _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable());
    260         }
    261         trainingstream.close();
    262       }
    263     }
    264 
    265     if ((params.threshold.get() < 1) && (params.best_of_howmany.isSet()))
    266       ELOG(2, "threshold parameter always overrules max_runs, both are specified.");
    267     // now perform the function approximation by optimizing the model function
    268     FunctionApproximation approximator(data, *model);
    269     if (model->isBoxConstraint() && approximator.checkParameterDerivatives()) {
    270       double l2error = std::numeric_limits<double>::max();
    271       // seed with current time
    272       srand((unsigned)time(0));
    273       unsigned int runs=0;
    274       // threshold overrules max_runs
    275       const double threshold = params.threshold.get();
    276       const unsigned int max_runs = (threshold >= 1.) ?
    277           (params.best_of_howmany.isSet() ? params.best_of_howmany.get() : 1) : 0;
    278       LOG(1, "INFO: Maximum runs is " << max_runs << " and threshold set to " << threshold << ".");
    279       do {
    280         // generate new random initial parameter values
    281         model->setParametersToRandomInitialValues(data);
    282         LOG(1, "INFO: Initial parameters of run " << runs << " are "
    283             << model->getParameters() << ".");
    284         approximator(FunctionApproximation::ParameterDerivative);
    285         LOG(1, "INFO: Final parameters of run " << runs << " are "
    286             << model->getParameters() << ".");
    287         const double new_l2error = data.getL2Error(*model);
    288         if (new_l2error < l2error) {
    289           // store currently best parameters
    290           l2error = new_l2error;
    291           bestparams = model->getParameters();
    292           LOG(1, "STATUS: New fit from run " << runs
    293               << " has better error of " << l2error << ".");
    294         }
    295       } while (( ++runs < max_runs) || (l2error > threshold));
    296       // reset parameters from best fit
    297       model->setParameters(bestparams);
    298       LOG(1, "INFO: Best parameters with L2 error of "
    299           << l2error << " are " << model->getParameters() << ".");
    300     } else {
    301       STATUS("No required parameter derivatives for a box constraint minimization known.");
    302       return Action::failure;
    303     }
    304 
    305     // create a map of each fragment with error.
    306     HomologyContainer::range_t fragmentrange = homologies.getHomologousGraphs(graph);
    307     TrainingData::L2ErrorConfigurationIndexMap_t WorseFragmentMap =
    308         data.getWorstFragmentMap(*model, fragmentrange);
    309     LOG(0, "RESULT: WorstFragmentMap " << WorseFragmentMap << ".");
    310 
    311     // print fitted potentials
    312     std::stringstream potentials;
    313     PotentialSerializer serialize(potentials);
    314     serialize();
    315     LOG(1, "STATUS: Resulting parameters are " << std::endl << potentials.str());
    316     std::ofstream returnstream(params.potential_file.get().string().c_str());
    317     if (returnstream.good()) {
    318       returnstream << potentials.str();
    319     }
     143  // training
     144  PotentialTrainer trainer;
     145  const bool status = trainer(
     146      homologies,
     147      graph,
     148      params.training_file.get(),
     149      params.threshold.get(),
     150      params.best_of_howmany.get());
     151  if (!status) {
     152    STATUS("No required parameter derivatives for a box constraint minimization known.");
     153    return Action::failure;
    320154  }
    321   delete model;
    322155
    323156  return Action::success;
Note: See TracChangeset for help on using the changeset viewer.