1 #ifndef HOPS_PARALLELTEMPERING_HPP
2 #define HOPS_PARALLELTEMPERING_HPP
4 #ifdef HOPS_MPI_SUPPORTED
/**
 * Parallel-tempering wrapper around a Markov chain implementation (MPI build).
 * It derives from MarkovChainImpl. Each process in `communicator` runs one
 * chain: the communicator size is the number of chains, and the process rank
 * is the chain index.
 *
 * NOTE(review): this copy of the header carries embedded line numbers and is
 * missing lines (e.g. the constructor signature and closing braces), so it
 * cannot compile as-is. Restore it from the upstream source before building.
 */
20 template<
typename MarkovChainImpl>
21 class ParallelTempering :
public MarkovChainImpl {
// Constructor (signature only partly visible). The exchange attempt
// probability defaults to 0.1.
25 double exchangeAttemptProbability = 0.1) :
26 MarkovChainImpl(markovChainImpl),
27 synchronizedRandomNumberGenerator(synchronizedRandomNumberGenerator),
28 exchangeAttemptProbability(exchangeAttemptProbability) {
// Clamp the requested exchange attempt probability into [0, 1].
29 if (exchangeAttemptProbability > 1) {
30 this->exchangeAttemptProbability = 1;
31 }
else if (exchangeAttemptProbability < 0) {
32 this->exchangeAttemptProbability = 0;
// Work on a private duplicate of MPI_COMM_WORLD.
// NOTE(review): no MPI_Comm_free is visible in this chunk. Confirm the
// duplicated communicator is released (e.g. in a destructor).
36 MPI_Comm_dup(MPI_COMM_WORLD, &communicator);
37 MPI_Comm_size(communicator, &numberOfChains);
40 MPI_Comm_rank(communicator, &chainIndex);
// Linear coldness ladder: rank 0 gets coldness 1. With more than one chain,
// the highest rank gets coldness 0. A single chain uses divisor 1 to avoid
// dividing by zero here.
// NOTE(review): computeExchangeAcceptanceProbability divides by coldness,
// which is 0 on the highest rank.
41 int largestChainIndex = numberOfChains == 1 ? 1 : numberOfChains - 1;
42 MarkovChainImpl::setColdness(1. -
static_cast<double>(chainIndex) / largestChainIndex);
// Draw body (method signature not visible in this chunk). It advances the
// wrapped chain by one step, then runs one exchange step. That step proposes
// a swap only with probability exchangeAttemptProbability. The bool returned
// by executeParallelTemperingStep is discarded here.
46 MarkovChainImpl::draw(randomNumberGenerator);
47 executeParallelTemperingStep();
// Forward record writing to MarkovChainImpl, but only on rank 0 of
// `communicator` (the chain the constructor gives coldness 1). When
// MarkovChainImpl has no writeRecordsToFile, the `if constexpr` branch is
// discarded at compile time.
50 void writeRecordsToFile(
const FileWriter *
const fileWriter)
const {
51 if constexpr(IsWriteRecordsToFileAvailable<MarkovChainImpl>::value) {
53 MPI_Comm_rank(communicator, &chainIndex);
54 if (chainIndex == 0) {
55 MarkovChainImpl::writeRecordsToFile(fileWriter);
// Body of the record-storing method (signature not visible in this chunk).
// It forwards to MarkovChainImpl::storeRecord() only on rank 0 of
// `communicator`, and only when MarkovChainImpl provides storeRecord.
61 if constexpr(IsStoreRecordAvailable<MarkovChainImpl>::value) {
63 MPI_Comm_rank(communicator, &chainIndex);
64 if (chainIndex == 0) {
65 MarkovChainImpl::storeRecord();
/**
 * One parallel-tempering exchange step.
 *
 * The decision to propose an exchange, and the choice of the adjacent pair
 * (i, i + 1), both come from synchronizedRandomNumberGenerator. This assumes
 * that generator is seeded identically on every rank (confirm), so that all
 * ranks agree on both.
 *
 * The two ranks in the pair compute the acceptance ratio against each other.
 * They swap states when the shared uniform draw `chance` does not exceed it.
 * The ratio exp((b_i - b_j)(E_i - E_j)) is symmetric in i and j, so both
 * partners reach the same decision, provided each receives the other's
 * values.
 *
 * NOTE(review): with numberOfChains == 1, generateChainPairForExchangeProposal
 * builds uniform_int_distribution bounds (0, -1), which violates its a <= b
 * precondition. Consider returning early when numberOfChains < 2.
 */
73 bool executeParallelTemperingStep() {
74 if (shouldProposeExchange()) {
76 std::pair<int, int> chainPair = generateChainPairForExchangeProposal();
// The rank comes from the duplicated `communicator`, not from MPI_COMM_WORLD directly.
78 MPI_Comm_rank(communicator, &world_rank);
79 if (chainPair.first == world_rank || chainPair.second == world_rank) {
80 int otherChainRank = world_rank == chainPair.first ? chainPair.second : chainPair.first;
82 double acceptanceProbability = computeExchangeAcceptanceProbability(otherChainRank);
83 double chance = uniformRealDistribution(synchronizedRandomNumberGenerator);
84 if (chance <= acceptanceProbability) {
85 exchangeStates(otherChainRank);
// Presumably the branch for ranks outside the pair (its header is not visible
// here). The value is discarded, but the draw keeps this rank's synchronized
// RNG stream aligned with the two ranks that drew `chance` above.
92 uniformRealDistribution(synchronizedRandomNumberGenerator);
/**
 * Acceptance ratio for swapping states with `otherChainRank`.
 *
 * Returns exp(diffColdness * diffNegativeLoglikelihoods). Each chain's
 * negative log-likelihood is first divided by its coldness.
 * thisChainProperties[0] is presumably the coldness; its initializer line is
 * missing from this copy. The result is not capped at 1, and the caller
 * compares it against a uniform draw.
 *
 * NOTE(review): the chain the constructor gives coldness 0 (the highest rank
 * when numberOfChains > 1) divides by zero below, producing +/-inf or NaN.
 * The returned ratio then degenerates (0, inf or NaN) for every exchange
 * involving that chain. Confirm what getNegativeLogLikelihoodOfCurrentState
 * returns at coldness 0 and special-case it.
 */
98 double computeExchangeAcceptanceProbability(
int otherChainRank) {
99 double coldness = this->getColdness();
// Presumably getNegativeLogLikelihoodOfCurrentState already includes the
// coldness factor, so this recovers the untempered value. Confirm.
100 double coldNegativeLogLikelihood = this->getNegativeLogLikelihoodOfCurrentState() / coldness;
101 double thisChainProperties[] = {
103 coldNegativeLogLikelihood
106 double otherChainProperties[2];
// Seed the buffer with this chain's values. The lines that fill it with the
// partner's values from otherChainRank are not visible in this copy.
107 std::memcpy(otherChainProperties, thisChainProperties,
sizeof(
double) * 2);
113 double diffColdness = thisChainProperties[0] - otherChainProperties[0];
114 double diffNegativeLoglikelihoods = thisChainProperties[1] - otherChainProperties[1];
116 double acceptanceProbability = std::exp(diffColdness * diffNegativeLoglikelihoods);
117 return acceptanceProbability;
// Swap this chain's current state with the state held by `otherChainRank`.
// The local state is copied, exchanged with the partner rank (presumably over
// `communicator`, on lines not visible in this copy), then written back with
// setState.
120 void exchangeStates(
int otherChainRank) {
121 VectorType thisState = MarkovChainImpl::getState();
126 MarkovChainImpl::setState(thisState);
// Pick a uniformly random adjacent pair (i, i + 1) with i in
// [0, numberOfChains - 2], using the synchronized RNG. Every rank gets the
// same pair, assuming identical seeding.
// NOTE(review): requires numberOfChains >= 2. For a single chain the bounds
// become (0, -1), which violates uniform_int_distribution's a <= b
// precondition.
129 std::pair<int, int> generateChainPairForExchangeProposal() {
131 int chainIndex = uniformIntDistribution(synchronizedRandomNumberGenerator,
132 std::uniform_int_distribution<int>::param_type(0,
133 numberOfChains - 2));
134 return std::make_pair(chainIndex, chainIndex + 1);
// Return true with probability exchangeAttemptProbability. uniformRealDistribution
// is default-constructed, so `chance` lies in [0, 1). Each call consumes exactly
// one draw from the synchronized RNG, whatever the outcome.
137 bool shouldProposeExchange() {
138 double chance = uniformRealDistribution(synchronizedRandomNumberGenerator);
139 return (chance < exchangeAttemptProbability);
// Probability with which each step proposes an exchange between two chains.
142 double getExchangeAttemptProbability()
const {
143 return exchangeAttemptProbability;
// Set the probability with which each step proposes an exchange.
// NOTE(review): unlike the constructor, this setter does not clamp the value
// into [0, 1]. In shouldProposeExchange, values above 1 act like 1 and values
// below 0 (or NaN) act like 0. However, getExchangeAttemptProbability reports
// the raw value, so consider clamping here for consistency.
146 void setExchangeAttemptProbability(
double newExchangeAttemptProbability) {
147 ParallelTempering::exchangeAttemptProbability = newExchangeAttemptProbability;
// Probability of proposing an exchange on each step. The constructor clamps it
// into [0, 1]; setExchangeAttemptProbability does not.
152 double exchangeAttemptProbability;
// Drawn with per-call bounds in generateChainPairForExchangeProposal.
153 std::uniform_int_distribution<int> uniformIntDistribution;
// Default-constructed (not in the constructor's initializer list), so draws
// are uniform on [0, 1).
154 std::uniform_real_distribution<double> uniformRealDistribution;
// Private duplicate of MPI_COMM_WORLD, created in the constructor.
155 MPI_Comm communicator;
// Fallback definition, presumably the #else branch of HOPS_MPI_SUPPORTED (the
// #else line is not visible in this copy). Both visible constructors throw
// std::runtime_error, so construction always fails on builds without MPI.
167 template<
typename MarkovChainImpl>
171 ) : MarkovChainImpl(markovChainImpl) {
172 throw std::runtime_error(
"MPI not supported on current platform");
175 double) : MarkovChainImpl(markovChainImpl) {
176 throw std::runtime_error(
"MPI not supported on current platform");
180 #endif //HOPS_MPI_SUPPORTED
182 #endif //HOPS_PARALLELTEMPERING_HPP