hops
TruncatedNormalDistribution.hpp
Go to the documentation of this file.
1 #ifndef HOPS_TRUNCATEDNORMALDISTRIBUTION_HPP
2 #define HOPS_TRUNCATEDNORMALDISTRIBUTION_HPP
3 
#include <cmath>
#include <limits>
#include <random>
5 
6 namespace hops {
7 
12  template<typename RealType>
14  public:
15  struct param_type {
16  RealType m_sigma;
17  RealType m_lowerBound = -std::numeric_limits<RealType>::infinity();
18  RealType m_upperBound = std::numeric_limits<RealType>::infinity();
19  RealType m_phiLower;
20  RealType m_phiUpper;
21 
22  void setPhi() {
23  if (m_lowerBound != -std::numeric_limits<RealType>::infinity())
25  else
26  m_phiLower = 0;
27 
28  if (m_upperBound != std::numeric_limits<RealType>::infinity())
30  else
31  m_phiUpper = 1;
32  }
33 
34  param_type(RealType m_sigma, RealType m_lowerBound, RealType m_upperBound) :
36  setPhi();
37  }
38  };
39 
40  template<typename Generator>
41  RealType operator()(Generator &g, const param_type &params) {
42  RealType uniformNumber = uniformRealDistribution.operator()(g);
43  return inverseCumulativeDensityFunction(uniformNumber, params);
44  }
45 
46  RealType inverseNormalization(const param_type &params) {
47  return params.m_phiUpper - params.m_phiLower;
48  }
49 
50  RealType probabilityDensity(RealType x, RealType m_sigma, RealType m_lowerBound, RealType m_upperBound){
51  RealType pdf = 1./(m_sigma * sqrt_2pi) * std::exp(-(1./2)*std::pow(x/m_sigma, 2));
52  return pdf / inverseNormalization(param_type(m_sigma, m_lowerBound, m_upperBound));
53  }
54 
55  private:
56  RealType inverseCumulativeDensityFunction(RealType x, const param_type &params) const {
57  x *= params.m_phiUpper - params.m_phiLower;
58  x += params.m_phiLower;
59  return inv_Phi(x) * params.m_sigma;
60  }
61 
62  std::uniform_real_distribution<> uniformRealDistribution{0, 1};
63 
64  static const constexpr RealType one_over_sqrt_2pi = RealType(0.398942280401432677939946);
65  static const constexpr RealType sqrt_2pi = RealType(2.50662827463100050241577);
66  static const constexpr RealType one_over_sqrt_2 = RealType(0.707106781186547524400845);
67 
68  // following code is adapted from https://github.com/rabauke/trng4
69 
70  static RealType Phi(RealType x) {
71  return 0.5 + 0.5 * std::erf(one_over_sqrt_2 * x);
72  }
73 
74  // this function is based on an approximation by Peter J. Acklam
75  // see http://home.online.no/~pjacklam/notes/invnorm/ for details
76 
77  struct inv_Phi_traits {
78  static RealType a(int i) throw() {
79  const RealType a_[] = {
80  -3.969683028665376e+01, 2.209460984245205e+02,
81  -2.759285104469687e+02, 1.383577518672690e+02,
82  -3.066479806614716e+01, 2.506628277459239e+00};
83  return a_[i];
84  }
85 
86  static RealType b(int i) throw() {
87  const RealType b_[] = {
88  -5.447609879822406e+01, 1.615858368580409e+02,
89  -1.556989798598866e+02, 6.680131188771972e+01,
90  -1.328068155288572e+01};
91  return b_[i];
92  }
93 
94  static RealType c(int i) throw() {
95  const RealType c_[] = {
96  -7.784894002430293e-03, -3.223964580411365e-01,
97  -2.400758277161838e+00, -2.549732539343734e+00,
98  4.374664141464968e+00, 2.938163982698783e+00};
99  return c_[i];
100  }
101 
102  static RealType d(int i) throw() {
103  const RealType d_[] = {
104  7.784695709041462e-03, 3.224671290700398e-01,
105  2.445134137142996e+00, 3.754408661907416e+00};
106  return d_[i];
107  }
108 
109  static RealType x_low() throw() { return 0.02425; }
110 
111  static RealType x_high() throw() { return 1.0 - 0.02425; }
112 
113  static RealType zero() throw() { return 0.0; }
114 
115  static RealType one() throw() { return 1.0; }
116 
117  static RealType one_half() throw() { return 0.5; }
118 
119  static RealType minus_two() throw() { return -2.0; }
120  };
121 
122  static RealType inv_Phi(RealType x) {
123  if (x < inv_Phi_traits::zero() || x > inv_Phi_traits::one())
124  return std::numeric_limits<RealType>::quiet_NaN();
125  if (x == inv_Phi_traits::zero())
126  return -std::numeric_limits<RealType>::infinity();
127  if (x == inv_Phi_traits::one())
128  return std::numeric_limits<RealType>::infinity();
129 
130  RealType t, q;
131  if (x < inv_Phi_traits::x_low()) {
132  // Rational approximation for lower region
133  q = std::sqrt(inv_Phi_traits::minus_two() * std::log(x));
134  t = (((((inv_Phi_traits::c(0) * q + inv_Phi_traits::c(1)) * q +
135  inv_Phi_traits::c(2)) * q + inv_Phi_traits::c(3)) * q +
136  inv_Phi_traits::c(4)) * q + inv_Phi_traits::c(5)) /
137  ((((inv_Phi_traits::d(0) * q + inv_Phi_traits::d(1)) * q +
138  inv_Phi_traits::d(2)) * q + inv_Phi_traits::d(3)) * q +
139  inv_Phi_traits::one());
140  } else if (x < inv_Phi_traits::x_high()) {
141  // Rational approximation for central region
142  q = x - inv_Phi_traits::one_half();
143  RealType r = q * q;
144  t = (((((inv_Phi_traits::a(0) * r + inv_Phi_traits::a(1)) * r +
145  inv_Phi_traits::a(2)) * r + inv_Phi_traits::a(3)) * r +
146  inv_Phi_traits::a(4)) * r + inv_Phi_traits::a(5)) * q /
147  (((((inv_Phi_traits::b(0) * r + inv_Phi_traits::b(1)) * r +
148  inv_Phi_traits::b(2)) * r + inv_Phi_traits::b(3)) * r +
149  inv_Phi_traits::b(4)) * r + inv_Phi_traits::one());
150  } else {
151  // Rational approximation for upper region
152  q = std::sqrt(inv_Phi_traits::minus_two() * std::log(1.0 - x));
153  t = -(((((inv_Phi_traits::c(0) * q + inv_Phi_traits::c(1)) * q +
154  inv_Phi_traits::c(2)) * q + inv_Phi_traits::c(3)) * q +
155  inv_Phi_traits::c(4)) * q + inv_Phi_traits::c(5)) /
156  ((((inv_Phi_traits::d(0) * q + inv_Phi_traits::d(1)) * q +
157  inv_Phi_traits::d(2)) * q + inv_Phi_traits::d(3)) * q +
158  inv_Phi_traits::one());
159  }
160 
161  // refinement by Halley rational method
162  if (std::numeric_limits<RealType>::epsilon() < 1e-9) {
163  RealType e(Phi(t) - x);
164  RealType u(e * sqrt_2pi * std::exp(t * t * inv_Phi_traits::one_half()));
165  t -= u / (inv_Phi_traits::one() + t * u * inv_Phi_traits::one_half());
166  }
167  return t;
168  }
169  };
170 }
171 
172 #endif //HOPS_TRUNCATEDNORMALDISTRIBUTION_HPP
hops::TruncatedNormalDistribution::param_type::m_sigma
RealType m_sigma
Definition: TruncatedNormalDistribution.hpp:16
hops::TruncatedNormalDistribution::param_type::setPhi
void setPhi()
Definition: TruncatedNormalDistribution.hpp:22
hops::TruncatedNormalDistribution::param_type
Definition: TruncatedNormalDistribution.hpp:15
hops
Definition: CsvReader.hpp:8
hops::TruncatedNormalDistribution::operator()
RealType operator()(Generator &g, const param_type &params)
Definition: TruncatedNormalDistribution.hpp:41
hops::TruncatedNormalDistribution::param_type::m_phiUpper
RealType m_phiUpper
Definition: TruncatedNormalDistribution.hpp:20
hops::TruncatedNormalDistribution::param_type::m_upperBound
RealType m_upperBound
Definition: TruncatedNormalDistribution.hpp:18
hops::TruncatedNormalDistribution
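Truncated normal distribution with mean 0 and standard deviation m_sigma (of the underlying normal), restricted to the interval [m_lowerBound, m_upperBound].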
Definition: TruncatedNormalDistribution.hpp:13
hops::TruncatedNormalDistribution::param_type::m_lowerBound
RealType m_lowerBound
Definition: TruncatedNormalDistribution.hpp:17
hops::TruncatedNormalDistribution::probabilityDensity
RealType probabilityDensity(RealType x, RealType m_sigma, RealType m_lowerBound, RealType m_upperBound)
Definition: TruncatedNormalDistribution.hpp:50
hops::TruncatedNormalDistribution::param_type::param_type
param_type(RealType m_sigma, RealType m_lowerBound, RealType m_upperBound)
Definition: TruncatedNormalDistribution.hpp:34
hops::TruncatedNormalDistribution::inverseNormalization
RealType inverseNormalization(const param_type &params)
Definition: TruncatedNormalDistribution.hpp:46
hops::TruncatedNormalDistribution::param_type::m_phiLower
RealType m_phiLower
Definition: TruncatedNormalDistribution.hpp:19