hops
CSmMALAProposal.hpp
Go to the documentation of this file.
1 #ifndef HOPS_CSMMALA_HPP
2 #define HOPS_CSMMALA_HPP
3 
4 #include <Eigen/Eigenvalues>
7 #include <random>
8 #include <utility>
9 
10 namespace hops {
11  namespace CSmMALAProposalDetails {
12  template<typename MatrixType>
14  MatrixType &sqrtInvMetric,
15  double &logSqrtDeterminant) {
16  Eigen::BDCSVD<MatrixType> solver(metric, Eigen::ComputeFullU);
17  sqrtInvMetric = solver.matrixU() * solver.singularValues().cwiseInverse().cwiseSqrt().asDiagonal() *
18  solver.matrixU().adjoint();
19  logSqrtDeterminant = 0.5 * solver.singularValues().array().log().sum();
20  }
21  }
22 
/**
 * @brief CSmMALA proposal mechanism (getName() returns "CSmMALA") on the polytope {x : Ax < b}.
 *
 * Proposals are drawn as driftedState + covarianceFactor * stateSqrtInvMetric * z, z ~ N(0, I).
 * The metric is fisherWeight * fisherScale * (expected Fisher information from ModelType, when the
 * model provides one) + (1 - fisherWeight) * (Dikin ellipsoid of the constraints).
 * Proposals violating Ax < b get a log acceptance value of -infinity.
 * The model is mixed in as the base class; this class calls ModelType::computeNegativeLogLikelihood,
 * ModelType::computeLogLikelihoodGradient and ModelType::computeExpectedFisherInformation.
 */
23  template<typename ModelType, typename InternalMatrixType>
24  class CSmMALAProposal : public ModelType {
25  public:
    /**
     * @brief Constructs the proposal mechanism on the polytope defined by Ax < b.
     * @param model        Model moved into the ModelType base.
     * @param A            Constraint matrix.
     * @param b            Constraint right-hand side.
     * @param currentState Starting state. NOTE(review): presumably must satisfy Ax < b; this is not
     *                     checked when the state is set.
     */
32  CSmMALAProposal(ModelType model, InternalMatrixType A, VectorType b, const VectorType& currentState);
33 
    // Draws a new proposal around driftedState, shaped by stateSqrtInvMetric and covarianceFactor.
34  void propose(RandomNumberGenerator &randomNumberGenerator);
35 
    // Makes the last evaluated proposal the current state, reusing its cached metric, drift and NLL.
36  void acceptProposal();
37 
    // Log acceptance value for the current proposal (-infinity outside Ax < b); also caches the
    // proposal-side quantities that acceptProposal() promotes.
38  [[nodiscard]] typename MatrixType::Scalar computeLogAcceptanceProbability();
39 
40  VectorType getState() const;
41 
    // Replaces the state and recomputes its metric, drift and negative log-likelihood.
42  void setState(VectorType newState);
43 
44  VectorType getProposal() const;
45 
46  typename MatrixType::Scalar getStepSize() const;
47 
    // Updates stepSize and the derived geometricFactor / covarianceFactor, then refreshes the state caches.
48  void setStepSize(typename MatrixType::Scalar newStepSize);
49 
    // Weight of the Fisher information vs. the Dikin ellipsoid in the metric; expected in [0, 1].
50  void setFisherWeight(typename MatrixType::Scalar newFisherWeight);
51 
    // NOTE(review): the declarations of getNegativeLogLikelihoodOfCurrentState() and getName()
    // (listing lines 52 and 54) are not visible in this rendering; their definitions appear below.
53 
55 
56  private:
    // Log-likelihood gradient at x scaled to unit Euclidean norm; zero vector if the model returns none.
57  VectorType computeTruncatedGradient(VectorType x);
58 
59  InternalMatrixType A;  // constraint matrix of Ax < b
60  VectorType b;  // constraint right-hand side of Ax < b
61 
    // Current state and the last proposal, each with its drifted mean.
62  VectorType state;
63  VectorType driftedState;
64  VectorType m_proposal;
65  VectorType driftedProposal;
    // Cached 0.5 * log|metric| and negative log-likelihood for the state and the proposal.
66  MatrixType::Scalar stateLogSqrtDeterminant = 0;
67  MatrixType::Scalar proposalLogSqrtDeterminant = 0;
68  MatrixType::Scalar stateNegativeLogLikelihood = 0;
69  MatrixType::Scalar proposalNegativeLogLikelihood = 0;
    // Metrics and their SVD-based inverse square roots; sized in the constructor.
70  MatrixType stateSqrtInvMetric;
71  MatrixType stateMetric;
72  MatrixType proposalSqrtInvMetric;
73  MatrixType proposalMetric;
74 
75  MatrixType::Scalar stepSize = 1;
76  MatrixType::Scalar fisherWeight = .5;
77  MatrixType::Scalar fisherScale = 1.;  // not modified by any code visible in this listing
78  MatrixType::Scalar geometricFactor = 0;  // set by setStepSize to A.cols() / (2 * stepSize^2)
79  MatrixType::Scalar covarianceFactor = 0;  // set by setStepSize to stepSize / sqrt(A.cols())
80 
81  std::normal_distribution<MatrixType::Scalar> normalDistribution{0., 1.};
    // NOTE(review): the `dikinEllipsoidCalculator` member (listing line 82) is used by the
    // constructor and the metric computations, but its declaration is missing from this rendering.
83  };
84 
85  template<typename ModelType, typename InternalMatrixType>
87  InternalMatrixType A,
89  const VectorType& currentState) :
90  ModelType(std::move(model)),
91  A(std::move(A)),
92  b(std::move(b)),
93  dikinEllipsoidCalculator(this->A, this->b) {
94  stateMetric = Eigen::Matrix<typename MatrixType::Scalar, Eigen::Dynamic, Eigen::Dynamic>::Zero(
95  currentState.rows(), currentState.rows());
96  proposalMetric = Eigen::Matrix<typename MatrixType::Scalar, Eigen::Dynamic, Eigen::Dynamic>::Zero(
97  currentState.rows(), currentState.rows());
98  setState(currentState);
99  setStepSize(1.);
100  m_proposal = state;
101  }
102 
103  template<typename ModelType, typename InternalMatrixType>
105  RandomNumberGenerator &randomNumberGenerator) {
106  for (long i = 0; i < m_proposal.rows(); ++i) {
107  m_proposal(i) = normalDistribution(randomNumberGenerator);
108  }
109  m_proposal = driftedState + covarianceFactor * (stateSqrtInvMetric * m_proposal);
110  }
111 
112  template<typename ModelType, typename InternalMatrixType>
114  state.swap(m_proposal);
115  driftedState.swap(driftedProposal);
116  stateSqrtInvMetric.swap(proposalSqrtInvMetric);
117  stateMetric.swap(proposalMetric);
118  stateLogSqrtDeterminant = proposalLogSqrtDeterminant;
119  stateNegativeLogLikelihood = proposalNegativeLogLikelihood;
120  }
121 
122  template<typename ModelType, typename InternalMatrixType>
123  typename MatrixType::Scalar
125  bool isProposalInteriorPoint = ((A * m_proposal - b).array() < 0).all();
126  if (!isProposalInteriorPoint) {
127  return -std::numeric_limits<typename MatrixType::Scalar>::infinity();
128  }
129 
130  // Important: compute gradient before fisher info or else x3cflux2 will throw
131  VectorType gradient = computeTruncatedGradient(m_proposal);
132  proposalMetric.setZero();
133  if (fisherWeight != 0) {
134  auto optionalFisherInformation = ModelType::computeExpectedFisherInformation(m_proposal);
135  if(optionalFisherInformation) {
136  auto fisherInformation = optionalFisherInformation.value();
137  proposalMetric += (fisherWeight * fisherScale * fisherInformation);
138  }
139  }
140  if (fisherWeight != 1) {
141  auto dikinEllipsoid = dikinEllipsoidCalculator.computeDikinEllipsoid(m_proposal);
142  proposalMetric += (1 - fisherWeight) * dikinEllipsoid;
143 
144  }
145  CSmMALAProposalDetails::computeMetricInfoForCSmMALAWithSvd(proposalMetric, proposalSqrtInvMetric,
146  proposalLogSqrtDeterminant);
147  driftedProposal = m_proposal +
148  0.5 * std::pow(covarianceFactor, 2) * proposalSqrtInvMetric * proposalSqrtInvMetric *
149  gradient;
150  proposalNegativeLogLikelihood = ModelType::computeNegativeLogLikelihood(m_proposal);
151 
152  double normDifference =
153  static_cast<double>((driftedState - m_proposal).transpose() * stateMetric * (driftedState - m_proposal)) -
154  static_cast<double>((state - driftedProposal).transpose() * proposalMetric * (state - driftedProposal));
155 
156  return -proposalNegativeLogLikelihood
157  + stateNegativeLogLikelihood
158  + proposalLogSqrtDeterminant
159  - stateLogSqrtDeterminant
160  + geometricFactor * normDifference;
161  }
162 
163  template<typename ModelType, typename InternalMatrixType>
165  return state;
166  }
167 
168  template<typename ModelType, typename InternalMatrixType>
170  state.swap(newState);
171  // Important: compute gradient before fisher info or else x3cflux2 will throw
172  VectorType gradient = computeTruncatedGradient(state);
173  stateMetric.setZero();
174  if (fisherWeight != 0) {
175  auto optionalFisherInformation = ModelType::computeExpectedFisherInformation(m_proposal);
176  if(optionalFisherInformation) {
177  auto fisherInformation = optionalFisherInformation.value();
178  proposalMetric += fisherWeight * fisherScale * fisherInformation;
179  }
180  }
181  if (fisherWeight != 1) {
182  auto dikinEllipsoid = dikinEllipsoidCalculator.computeDikinEllipsoid(state);
183  stateMetric += (1 - fisherWeight) * dikinEllipsoid;
184  }
186  stateSqrtInvMetric,
187  stateLogSqrtDeterminant);
188  driftedState = state + 0.5 * std::pow(covarianceFactor, 2) * stateSqrtInvMetric * stateSqrtInvMetric *
189  gradient;
190  stateNegativeLogLikelihood = ModelType::computeNegativeLogLikelihood(state);
191  }
192 
193  template<typename ModelType, typename InternalMatrixType>
195  return m_proposal;
196  }
197 
198  template<typename ModelType, typename InternalMatrixType>
200  return stepSize;
201  }
202 
203  template<typename ModelType, typename InternalMatrixType>
204  void CSmMALAProposal<ModelType, InternalMatrixType>::setStepSize(typename MatrixType::Scalar newStepSize) {
205  stepSize = newStepSize;
206  geometricFactor = A.cols() / (2 * stepSize * stepSize);
207  covarianceFactor = stepSize / std::sqrt(A.cols());
208  setState(state);
209  }
210 
211  template<typename ModelType, typename InternalMatrixType>
213  typename MatrixType::Scalar newFisherWeight) {
214  if (fisherWeight > 1 || fisherWeight < 0) {
215  throw std::runtime_error("fisherWeight should be in [0, 1].");
216  }
217  fisherWeight = newFisherWeight;
218  setState(state);
219  }
220 
221  template<typename ModelType, typename InternalMatrixType>
223  return stateNegativeLogLikelihood;
224  }
225 
226  template<typename ModelType, typename InternalMatrixType>
228  return "CSmMALA";
229  }
230 
231  template<typename ModelType, typename InternalMatrixType>
233  auto gradient = ModelType::computeLogLikelihoodGradient(x);
234  if(gradient) {
235  double norm = gradient.value().norm();
236  if (norm != 0) {
237  gradient.value() /= norm;
238  }
239  return gradient.value();
240  }
241  return VectorType::Zero(x.rows());
242  }
243 }
244 
245 #endif //HOPS_CSMMALA_HPP
hops::CSmMALAProposal::getNegativeLogLikelihoodOfCurrentState
double getNegativeLogLikelihoodOfCurrentState()
Definition: CSmMALAProposal.hpp:222
hops::MatrixType
Eigen::MatrixXd MatrixType
Definition: MatrixType.hpp:7
hops::CSmMALAProposal::getState
VectorType getState() const
Definition: CSmMALAProposal.hpp:164
IsAddMessageAvailabe.hpp
pcg_detail::engine
Definition: pcg_random.hpp:364
hops::CSmMALAProposal::getStepSize
MatrixType::Scalar getStepSize() const
Definition: CSmMALAProposal.hpp:199
hops::CSmMALAProposal::CSmMALAProposal
CSmMALAProposal(ModelType model, InternalMatrixType A, VectorType b, const VectorType &currentState)
Constructs the proposal mechanism on the polytope defined by Ax < b.
Definition: CSmMALAProposal.hpp:86
DikinProposal.hpp
hops::CSmMALAProposalDetails::computeMetricInfoForCSmMALAWithSvd
void computeMetricInfoForCSmMALAWithSvd(const MatrixType &metric, MatrixType &sqrtInvMetric, double &logSqrtDeterminant)
Definition: CSmMALAProposal.hpp:13
hops::CSmMALAProposal::acceptProposal
void acceptProposal()
Definition: CSmMALAProposal.hpp:113
hops
Definition: CsvReader.hpp:8
hops::DikinEllipsoidCalculator
Definition: DikinEllipsoidCalculator.hpp:11
hops::CSmMALAProposal::getProposal
VectorType getProposal() const
Definition: CSmMALAProposal.hpp:194
hops::CSmMALAProposal::setState
void setState(VectorType newState)
Definition: CSmMALAProposal.hpp:169
string
NAME string(REPLACE ".cpp" "_bin" example_name ${example_filename}) if($
Definition: hops/Third-party/HighFive/src/examples/CMakeLists.txt:6
hops::CSmMALAProposal::setFisherWeight
void setFisherWeight(typename MatrixType::Scalar newFisherWeight)
Definition: CSmMALAProposal.hpp:212
hops::CSmMALAProposal::setStepSize
void setStepSize(typename MatrixType::Scalar newStepSize)
Definition: CSmMALAProposal.hpp:204
hops::CSmMALAProposal
Definition: CSmMALAProposal.hpp:24
hops::VectorType
Eigen::VectorXd VectorType
Definition: VectorType.hpp:7
hops::CSmMALAProposal::computeLogAcceptanceProbability
MatrixType::Scalar computeLogAcceptanceProbability()
Definition: CSmMALAProposal.hpp:124
hops::CSmMALAProposal::getName
std::string getName()
Definition: CSmMALAProposal.hpp:227
hops::CSmMALAProposal::propose
void propose(RandomNumberGenerator &randomNumberGenerator)
Definition: CSmMALAProposal.hpp:104