adevs
adevs_rk_45.h
1 
31 #ifndef _adevs_rk_45_h_
32 #define _adevs_rk_45_h_
33 #include "adevs_hybrid.h"
34 #include <cmath>
35 
36 namespace adevs
37 {
38 
43 template <typename X> class rk_45:
44  public ode_solver<X>
45 {
46  public:
52  rk_45(ode_system<X>* sys, double err_tol, double h_max);
54  ~rk_45();
55  double integrate(double* q, double h_lim);
56  void advance(double* q, double h);
57  private:
58  double *dq, // derivative
59  *qq, // trial solution
60  *t, // temporary variables for computing stages
61  *k[6]; // the six RK stages
62  const double err_tol; // Error tolerance
63  const double h_max; // Maximum time step
64  double h_cur; // Previous successful step size
65  // Compute a trial step of size h, store the result in qq, and return the error
66  double trial_step(double h);
67 };
68 
69 template <typename X>
70 rk_45<X>::rk_45(ode_system<X>* sys, double err_tol, double h_max):
71  ode_solver<X>(sys),err_tol(err_tol),h_max(h_max),h_cur(h_max)
72 {
73  for (int i = 0; i < 6; i++)
74  k[i] = new double[sys->numVars()];
75  dq = new double[sys->numVars()];
76  qq = new double[sys->numVars()];
77  t = new double[sys->numVars()];
78 }
79 
80 template <typename X>
82 {
83  delete [] dq;
84  delete [] t;
85  for (int i = 0; i < 6; i++)
86  delete [] k[i];
87 }
88 
89 template <typename X>
90 void rk_45<X>::advance(double* q, double h)
91 {
92  double dt;
93  while ((dt = integrate(q,h)) < h) h -= dt;
94 }
95 
96 template <typename X>
97 double rk_45<X>::integrate(double* q, double h_lim)
98 {
99  // Initial error estimate and step size
100  double err = DBL_MAX, h = std::min(h_cur*1.1,std::min(h_max,h_lim));
101  for (;;) {
102  // Copy q to the trial vector
103  for (int i = 0; i < this->sys->numVars(); i++) qq[i] = q[i];
104  // Make the trial step which will be stored in qq
105  err = trial_step(h);
106  // If the error is ok, then we have found the proper step size
107  if (err <= err_tol) {
108  if (h_cur <= h_lim) h_cur = h;
109  break;
110  }
111  // Otherwise shrink the step size and try again
112  else {
113  double h_guess = 0.8*pow(err_tol*pow(h,4.0)/fabs(err),0.25);
114  if (h < h_guess) h *= 0.8;
115  else h = h_guess;
116  }
117  }
118  // Copy the trial solution to q and return the step size that was selected
119  for (int i = 0; i < this->sys->numVars(); i++) q[i] = qq[i];
120  return h;
121 }
122 
123 template <typename X>
124 double rk_45<X>::trial_step(double step)
125 {
126  // Compute k1
127  this->sys->der_func(qq,dq);
128  for (int j = 0; j < this->sys->numVars(); j++) k[0][j] = step*dq[j];
129  // Compute k2
130  for (int j = 0; j < this->sys->numVars(); j++) t[j] = qq[j] + 0.5*k[0][j];
131  this->sys->der_func(t,dq);
132  for (int j = 0; j < this->sys->numVars(); j++) k[1][j] = step*dq[j];
133  // Compute k3
134  for (int j = 0; j < this->sys->numVars(); j++) t[j] = qq[j] + 0.25*(k[0][j]+k[1][j]);
135  this->sys->der_func(t,dq);
136  for (int j = 0; j < this->sys->numVars(); j++) k[2][j] = step*dq[j];
137  // Compute k4
138  for (int j = 0; j < this->sys->numVars(); j++) t[j] = qq[j] - k[1][j] + 2.0*k[2][j];
139  this->sys->der_func(t,dq);
140  for (int j = 0; j < this->sys->numVars(); j++) k[3][j] = step*dq[j];
141  // Compute k5
142  for (int j = 0; j < this->sys->numVars(); j++)
143  t[j] = qq[j] + (7.0/27.0)*k[0][j] + (10.0/27.0)*k[1][j] + (1.0/27.0)*k[3][j];
144  this->sys->der_func(t,dq);
145  for (int j = 0; j < this->sys->numVars(); j++) k[4][j] = step*dq[j];
146  // Compute k6
147  for (int j = 0; j < this->sys->numVars(); j++)
148  t[j] = qq[j] + (28.0/625.0)*k[0][j] - 0.2*k[1][j] + (546.0/625.0)*k[2][j]
149  + (54.0/625.0)*k[3][j] - (378.0/625.0)*k[4][j];
150  this->sys->der_func(t,dq);
151  for (int j = 0 ; j < this->sys->numVars(); j++) k[5][j] = step*dq[j];
152  // Compute next state and the approximate error
153  double err = 0.0;
154  for (int j = 0; j < this->sys->numVars(); j++)
155  {
156  // Next state
157  qq[j] += (1.0/24.0)*k[0][j] + (5.0/48.0)*k[3][j] +
158  (27.0/56.0)*k[4][j] + (125.0/336.0)*k[5][j];
159  // Componennt wise maximum of the approximate error
160  err = std::max(err,
161  fabs(k[0][j]/8.0+2.0*k[2][j]/3.0+k[3][j]/16.0-27.0*k[4][j]/56.0
162  -125.0*k[5][j]/336.0));
163  }
164  // Return the error
165  return err;
166 }
167 
168 } // end of namespace
169 #endif
170 
rk_45(ode_system< X > *sys, double err_tol, double h_max)
Definition: adevs_rk_45.h:70
Definition: adevs_rk_45.h:43
~rk_45()
Destructor.
Definition: adevs_rk_45.h:81
int numVars() const
Get the number of state variables.
Definition: adevs_hybrid.h:52
double integrate(double *q, double h_lim)
Definition: adevs_rk_45.h:97
Definition: adevs_hybrid.h:45
void advance(double *q, double h)
Definition: adevs_rk_45.h:90
Definition: adevs_hybrid.h:367