File: 1 edited

Legend:
- Unmodified
- Added
- Removed

src/Actions/PotentialAction/FitPotentialAction.cpp
r707a2b r0ea063 55 55 #include "Fragmentation/Homology/HomologyGraph.hpp" 56 56 #include "Fragmentation/Summation/SetValues/Fragment.hpp" 57 #include "FunctionApproximation/Extractors.hpp" 58 #include "FunctionApproximation/FunctionApproximation.hpp" 59 #include "FunctionApproximation/FunctionModel.hpp" 60 #include "FunctionApproximation/TrainingData.hpp" 61 #include "FunctionApproximation/writeDistanceEnergyTable.hpp" 62 #include "Potentials/CompoundPotential.hpp" 63 #include "Potentials/Exceptions.hpp" 64 #include "Potentials/PotentialDeserializer.hpp" 57 #include "Potentials/EmpiricalPotential.hpp" 65 58 #include "Potentials/PotentialFactory.hpp" 66 59 #include "Potentials/PotentialRegistry.hpp" 67 60 #include "Potentials/PotentialSerializer.hpp" 61 #include "Potentials/PotentialTrainer.hpp" 68 62 #include "Potentials/SerializablePotential.hpp" 63 #include "World.hpp" 69 64 70 65 using namespace MoleCuilder; … … 75 70 /** =========== define the function ====================== */ 76 71 77 HomologyGraph getFirstGraphwithSpecifiedElements(78 const HomologyContainer &homologies,79 const SerializablePotential::ParticleTypes_t &types)80 {81 ASSERT( !types.empty(),82 "getFirstGraphwithSpecifiedElements() - charges is empty?");83 // create charges84 Fragment::charges_t charges;85 charges.resize(types.size());86 std::transform(types.begin(), types.end(),87 charges.begin(), boost::lambda::_1);88 // convert into count map89 Extractors::elementcounts_t counts_per_charge =90 Extractors::_detail::getElementCounts(charges);91 ASSERT( !counts_per_charge.empty(),92 "getFirstGraphwithSpecifiedElements() - charge counts are empty?");93 LOG(2, "DEBUG: counts_per_charge is " << counts_per_charge << ".");94 // we want to check each (unique) key only once95 for (HomologyContainer::const_key_iterator iter = homologies.key_begin();96 iter != homologies.key_end(); iter = homologies.getNextKey(iter)) {97 // check if every element has the right number of counts98 
Extractors::elementcounts_t::const_iterator countiter = counts_per_charge.begin();99 for (; countiter != counts_per_charge.end(); ++countiter)100 if (!(*iter).hasTimesAtomicNumber(101 static_cast<size_t>(countiter->first),102 static_cast<size_t>(countiter->second))103 )104 break;105 if( countiter == counts_per_charge.end())106 return *iter;107 }108 return HomologyGraph();109 }110 111 SerializablePotential::ParticleTypes_t getNumbersFromElements(112 const std::vector<const element *> &fragment)113 {114 SerializablePotential::ParticleTypes_t fragmentnumbers;115 std::transform(fragment.begin(), fragment.end(), std::back_inserter(fragmentnumbers),116 boost::bind(&element::getAtomicNumber, _1));117 return fragmentnumbers;118 }119 120 121 72 ActionState::ptr PotentialFitPotentialAction::performCall() { 122 73 // fragment specifies the homology fragment to use 123 74 SerializablePotential::ParticleTypes_t fragmentnumbers = 124 getNumbersFromElements(params.fragment.get());75 PotentialTrainer::getNumbersFromElements(params.fragment.get()); 125 76 126 77 // either charges and a potential is specified or a file 127 if (boost::filesystem::exists(params.potential_file.get())) { 128 std::ifstream returnstream(params.potential_file.get().string().c_str()); 129 if (returnstream.good()) { 130 try { 131 PotentialDeserializer deserialize(returnstream); 132 deserialize(); 133 } catch (SerializablePotentialMissingValueException &e) { 134 if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e)) 135 STATUS("Missing value when parsing information for potential "+*key+"."); 136 else 137 STATUS("Missing value parsing information for potential with unknown key."); 138 return Action::failure; 139 } catch (SerializablePotentialIllegalKeyException &e) { 140 if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e)) 141 STATUS("Illegal key parsing information for potential "+*key+"."); 142 else 143 STATUS("Illegal key parsing information for 
potential with unknown key."); 144 return Action::failure; 145 } 146 } else { 147 STATUS("Failed to parse from "+params.potential_file.get().string()+"."); 148 return Action::failure; 78 if (params.charges.get().empty()) { 79 STATUS("No charges given!"); 80 return Action::failure; 81 } else { 82 // charges specify the potential type 83 SerializablePotential::ParticleTypes_t chargenumbers = 84 PotentialTrainer::getNumbersFromElements(params.charges.get()); 85 86 LOG(0, "STATUS: I'm training now a " << params.potentialtype.get() 87 << " potential on charges " << chargenumbers << " on data from World's homologies."); 88 89 // register desired potential and an additional constant one 90 { 91 EmpiricalPotential *potential = 92 PotentialFactory::getInstance().createInstance( 93 params.potentialtype.get(), 94 chargenumbers); 95 // check whether such a potential already exists 96 const std::string potential_name = potential->getName(); 97 if (PotentialRegistry::getInstance().isPresentByName(potential_name)) { 98 delete potential; 99 potential = PotentialRegistry::getInstance().getByName(potential_name); 100 } else 101 PotentialRegistry::getInstance().registerInstance(potential); 149 102 } 150 returnstream.close(); 151 152 LOG(0, "STATUS: I'm training now a set of potentials parsed from " 153 << params.potential_file.get().string() << " on a fragment " 154 << fragmentnumbers << " on data from World's homologies."); 155 156 } else { 157 if (params.charges.get().empty()) { 158 STATUS("Neither charges nor potential file given!"); 159 return Action::failure; 160 } else { 161 // charges specify the potential type 162 SerializablePotential::ParticleTypes_t chargenumbers = 163 getNumbersFromElements(params.charges.get()); 164 165 LOG(0, "STATUS: I'm training now a " << params.potentialtype.get() 166 << " potential on charges " << chargenumbers << " on data from World's homologies."); 167 168 // register desired potential and an additional constant one 169 { 170 EmpiricalPotential 
*potential = 171 PotentialFactory::getInstance().createInstance( 172 params.potentialtype.get(), 173 chargenumbers); 174 // check whether such a potential already exists 175 const std::string potential_name = potential->getName(); 176 if (PotentialRegistry::getInstance().isPresentByName(potential_name)) { 177 delete potential; 178 potential = PotentialRegistry::getInstance().getByName(potential_name); 179 } else 180 PotentialRegistry::getInstance().registerInstance(potential); 181 } 182 { 183 EmpiricalPotential *constant = 184 PotentialFactory::getInstance().createInstance( 185 std::string("constant"), 186 SerializablePotential::ParticleTypes_t()); 187 // check whether such a potential already exists 188 const std::string constant_name = constant->getName(); 189 if (PotentialRegistry::getInstance().isPresentByName(constant_name)) { 190 delete constant; 191 constant = PotentialRegistry::getInstance().getByName(constant_name); 192 } else 193 PotentialRegistry::getInstance().registerInstance(constant); 194 } 103 { 104 EmpiricalPotential *constant = 105 PotentialFactory::getInstance().createInstance( 106 std::string("constant"), 107 SerializablePotential::ParticleTypes_t()); 108 // check whether such a potential already exists 109 const std::string constant_name = constant->getName(); 110 if (PotentialRegistry::getInstance().isPresentByName(constant_name)) { 111 delete constant; 112 constant = PotentialRegistry::getInstance().getByName(constant_name); 113 } else 114 PotentialRegistry::getInstance().registerInstance(constant); 195 115 } 196 116 } 197 117 198 118 // parse homologies into container 199 HomologyContainer &homologies = World::getInstance().getHomologies();119 const HomologyContainer &homologies = World::getInstance().getHomologies(); 200 120 201 121 // first we try to look into the HomologyContainer … … 211 131 212 132 // then we ought to pick the right HomologyGraph ... 
213 const HomologyGraph graph = getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers); 133 const HomologyGraph graph = 134 PotentialTrainer::getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers); 214 135 if (graph != HomologyGraph()) { 215 136 LOG(1, "First representative graph containing fragment " … … 220 141 } 221 142 222 // fit potential 223 FunctionModel *model = new CompoundPotential(graph); 224 ASSERT( model != NULL, 225 "PotentialFitPotentialAction::performCall() - model is NULL."); 226 227 /******************** TRAINING ********************/ 228 // fit potential 229 FunctionModel::parameters_t bestparams(model->getParameterDimension(), 0.); 230 { 231 // Afterwards we go through all of this type and gather the distance and the energy value 232 TrainingData data(model->getSpecificFilter()); 233 data(homologies.getHomologousGraphs(graph)); 234 235 // print distances and energies if desired for debugging 236 if (!data.getTrainingInputs().empty()) { 237 // print which distance is which 238 size_t counter=1; 239 if (DoLog(3)) { 240 const FunctionModel::arguments_t &inputs = data.getAllArguments()[0]; 241 for (FunctionModel::arguments_t::const_iterator iter = inputs.begin(); 242 iter != inputs.end(); ++iter) { 243 const argument_t &arg = *iter; 244 LOG(3, "DEBUG: distance " << counter++ << " is between (#" 245 << arg.indices.first << "c" << arg.types.first << "," 246 << arg.indices.second << "c" << arg.types.second << ")."); 247 } 248 } 249 250 // print table 251 if (params.training_file.get().string().empty()) { 252 LOG(3, "DEBUG: I gathered the following training data:\n" << 253 _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable())); 254 } else { 255 std::ofstream trainingstream(params.training_file.get().string().c_str()); 256 if (trainingstream.good()) { 257 LOG(3, "DEBUG: Writing training data to file " << 258 params.training_file.get().string() << "."); 259 trainingstream << 
_detail::writeDistanceEnergyTable(data.getDistanceEnergyTable()); 260 } 261 trainingstream.close(); 262 } 263 } 264 265 if ((params.threshold.get() < 1) && (params.best_of_howmany.isSet())) 266 ELOG(2, "threshold parameter always overrules max_runs, both are specified."); 267 // now perform the function approximation by optimizing the model function 268 FunctionApproximation approximator(data, *model); 269 if (model->isBoxConstraint() && approximator.checkParameterDerivatives()) { 270 double l2error = std::numeric_limits<double>::max(); 271 // seed with current time 272 srand((unsigned)time(0)); 273 unsigned int runs=0; 274 // threshold overrules max_runs 275 const double threshold = params.threshold.get(); 276 const unsigned int max_runs = (threshold >= 1.) ? 277 (params.best_of_howmany.isSet() ? params.best_of_howmany.get() : 1) : 0; 278 LOG(1, "INFO: Maximum runs is " << max_runs << " and threshold set to " << threshold << "."); 279 do { 280 // generate new random initial parameter values 281 model->setParametersToRandomInitialValues(data); 282 LOG(1, "INFO: Initial parameters of run " << runs << " are " 283 << model->getParameters() << "."); 284 approximator(FunctionApproximation::ParameterDerivative); 285 LOG(1, "INFO: Final parameters of run " << runs << " are " 286 << model->getParameters() << "."); 287 const double new_l2error = data.getL2Error(*model); 288 if (new_l2error < l2error) { 289 // store currently best parameters 290 l2error = new_l2error; 291 bestparams = model->getParameters(); 292 LOG(1, "STATUS: New fit from run " << runs 293 << " has better error of " << l2error << "."); 294 } 295 } while (( ++runs < max_runs) || (l2error > threshold)); 296 // reset parameters from best fit 297 model->setParameters(bestparams); 298 LOG(1, "INFO: Best parameters with L2 error of " 299 << l2error << " are " << model->getParameters() << "."); 300 } else { 301 STATUS("No required parameter derivatives for a box constraint minimization known."); 302 return 
Action::failure; 303 } 304 305 // create a map of each fragment with error. 306 HomologyContainer::range_t fragmentrange = homologies.getHomologousGraphs(graph); 307 TrainingData::L2ErrorConfigurationIndexMap_t WorseFragmentMap = 308 data.getWorstFragmentMap(*model, fragmentrange); 309 LOG(0, "RESULT: WorstFragmentMap " << WorseFragmentMap << "."); 310 311 // print fitted potentials 312 std::stringstream potentials; 313 PotentialSerializer serialize(potentials); 314 serialize(); 315 LOG(1, "STATUS: Resulting parameters are " << std::endl << potentials.str()); 316 std::ofstream returnstream(params.potential_file.get().string().c_str()); 317 if (returnstream.good()) { 318 returnstream << potentials.str(); 319 } 143 // training 144 PotentialTrainer trainer; 145 const bool status = trainer( 146 homologies, 147 graph, 148 params.training_file.get(), 149 params.threshold.get(), 150 params.best_of_howmany.get()); 151 if (!status) { 152 STATUS("No required parameter derivatives for a box constraint minimization known."); 153 return Action::failure; 320 154 } 321 delete model;322 155 323 156 return Action::success;
Note: See TracChangeset for help on using the changeset viewer.