#include "std_includes.h"

// Selects the behaviour of steps() and path().
enum Mode {
  // Forward evaluation only (handled by path()).
  PRIMAL,
  // steps(): store a checkpoint of the section's entry state, then run the
  // section forward without recording intermediate values.
  CONTEXT_FREE_JOINT_FORWARD,
  // steps(): restore the most recent checkpoint, then record and adjoin
  // the section.
  CONTEXT_FREE_JOINT_BACKWARD,
  // steps(): record and adjoin a section starting from the current state.
  // path(): adjoin a whole path, section by section.
  CONTEXT_SENSITIVE_JOINT
};

// Processes steps from..to-1 of the recurrence
//   x += dt*p[i]*sin(x*t) + p[i]*cos(x*t)*sqrt(dt)*dW_j[i];  t += dt;
// with dt = 1/p.size() and t starting at from*dt.
//
// mode selects what is done with the section:
//   CONTEXT_FREE_JOINT_FORWARD   pushes (x,t) at section entry as a
//                                checkpoint and advances x to step `to`
//                                without recording anything else.
//   CONTEXT_FREE_JOINT_BACKWARD  pops the most recent checkpoint into (x,t)
//                                and then behaves like
//                                CONTEXT_SENSITIVE_JOINT.
//   CONTEXT_SENSITIVE_JOINT      advances x while pushing every pre-step
//                                value, then walks the steps in reverse:
//                                pa[i] is incremented by the adjoint of
//                                p[i] and xa is replaced by the adjoint of
//                                x at section entry. On return x holds its
//                                value at the end of the section.
//   any other mode               trips the assertion.
//
// The stacks are function-local statics, so checkpoints and recorded values
// are shared by all calls of the same instantiation: every BACKWARD call
// consumes the checkpoint of the latest not-yet-consumed FORWARD call.
template<typename T>
void steps(Mode mode, int from, int to, T& x, T &xa,
           const vector<T>& p, vector<T>& pa,
           const vector<double>& dW_j) {
  static stack<T> tbr_T; static stack<double> tbr_d;
  const int n=p.size();
  const double dt=1.0/n, sqrt_dt=sqrt(dt);
  double t=from*dt;
  // An empty range touches no element; otherwise every index must be valid.
  assert(from>=to ||
         (from>=0 && to<=n && to<=static_cast<int>(dW_j.size())));
  switch (mode) {
    case CONTEXT_FREE_JOINT_FORWARD:
      tbr_T.push(x); tbr_d.push(t);  // checkpoint at section entry
      for (int i=from;i<to;i++) {
        x+=dt*p[i]*sin(x*t)+p[i]*cos(x*t)*sqrt_dt*dW_j[i];
        t+=dt;
      }
      break;
    case CONTEXT_FREE_JOINT_BACKWARD:
      // A checkpoint written by a matching FORWARD call must be available.
      assert(!tbr_T.empty() && !tbr_d.empty());
      t=tbr_d.top(); tbr_d.pop(); x=tbr_T.top(); tbr_T.pop();
      // fall through
    case CONTEXT_SENSITIVE_JOINT: {
      // Recording sweep: keep the value of x seen by each step.
      for (int i=from;i<to;i++) {
        tbr_T.push(x);
        x+=dt*p[i]*sin(x*t)+p[i]*cos(x*t)*sqrt_dt*dW_j[i];
        t+=dt;
      }
      // The reverse sweep overwrites x with recorded values; keep the
      // section result in the state type so nothing is lost for T != double.
      const T y=x;
      for (int i=to-1;i>=from;i--) {
        t-=dt;
        x=tbr_T.top(); tbr_T.pop();
        // pa[i] must be updated before xa is overwritten: both use the
        // adjoint of the step's output.
        pa[i]+=(dt*sin(x*t)+cos(x*t)*sqrt_dt*dW_j[i])*xa;
        xa=(1+dt*p[i]*t*cos(x*t)-p[i]*t*sin(x*t)*sqrt_dt*dW_j[i])*xa;
      }
      x=y;
      break;
    }
    default: assert(false); break;
  }
}

// Evaluates, or evaluates and adjoins, one path of n = dW_j.size() steps.
//
//   PRIMAL                  advances x through all n steps of
//                             x += dt*p[i]*sin(x*t)
//                                + p[i]*cos(x*t)*sqrt(dt)*dW_j[i];  t += dt;
//                           with dt = 1/n and t starting at 0. xa and pa
//                           are not touched.
//   CONTEXT_SENSITIVE_JOINT splits the path into sections of ncs steps,
//                           runs all but the last through steps() in
//                           checkpointing mode, adjoins the last section
//                           directly and then adjoins the earlier sections
//                           in reverse order. On return x is the path's
//                           final value, xa the adjoint of the initial x,
//                           and pa has been incremented.
//   any other mode          trips the assertion.
//
// ncs need not divide n: the last section simply takes the remaining steps.
// A non-positive ncs, or one larger than n, is treated as a single section
// covering the whole path. steps() derives its step size from p.size(), so
// p is expected to have the same length as dW_j.
template<typename T>
void path(Mode mode, const int ncs,
  T& x, T& xa, const vector<T>& p, vector<T>& pa,
  const vector<double>& dW_j) {
  const int n=dW_j.size();
  switch (mode) {
    case PRIMAL: {
      const double dt=1.0/n;
      double t=0;
      for (int i=0;i<n;i++) {
        x+=dt*p[i]*sin(x*t)+p[i]*cos(x*t)*sqrt(dt)*dW_j[i];
        t+=dt;
      }
      break;
    }
    case CONTEXT_SENSITIVE_JOINT: {
      if (n<=0) break;  // no steps: x and xa pass through unchanged
      // Effective section length; guards against a zero/negative stride
      // (endless loop) and against a stride longer than the path.
      const int len=(ncs>0 && ncs<n) ? ncs : n;
      // First step of the final section. Equals n-len when len divides n;
      // otherwise the final section is the shorter remainder.
      const int last=((n-1)/len)*len;
      for (int i=0;i<last;i+=len)
        steps(CONTEXT_FREE_JOINT_FORWARD,i,i+len,x,xa,p,pa,dW_j);
      steps(CONTEXT_SENSITIVE_JOINT,last,n,x,xa,p,pa,dW_j);
      // The backward sweep leaves x at the end of the first section;
      // remember the path result so it can be handed back.
      const T y=x;
      for (int i=last-len;i>=0;i-=len)
        steps(CONTEXT_FREE_JOINT_BACKWARD,i,i+len,x,xa,p,pa,dW_j);
      x=y;
      break;
    }
    default: assert(false); break;
  }
}

// Adjoint of the Monte Carlo mean over the paths in dW.
//
// On entry x is the common initial state and xa the adjoint of the mean.
// On return x is the mean of the final states over all dW.size() paths,
// xa is the adjoint of the initial state, and pa has been incremented by
// the adjoints of p. Each path is first evaluated with path(PRIMAL) to form
// the mean and then re-evaluated and adjoined with
// path(CONTEXT_SENSITIVE_JOINT), using sections of ncs steps.
void f_a(const int ncs, double& x, double& xa, 
    const vector<double>& p, vector<double>& pa,
    const vector<vector<double>>& dW) {
  const int num_paths=dW.size();
  const double x_start=x;

  // Augmented primal: sum the final state of every path.
  double total=0;
  for (int k=0;k<num_paths;++k) {
    x=x_start;
    path(PRIMAL,ncs,x,xa,p,pa,dW[k]);
    total+=x;
  }
  const double mean=total/num_paths;

  // Adjoint of "mean = total/num_paths": every path's final state receives
  // the same share of the incoming adjoint.
  double total_a=0;
  total_a+=xa/num_paths;

  // Adjoint of the summation loop, visiting the paths in reverse order.
  double x_start_a=0;
  for (int k=num_paths-1;k>=0;--k) {
    x=x_start;
    xa=0;
    xa+=total_a;
    path(CONTEXT_SENSITIVE_JOINT,ncs,x,xa,p,pa,dW[k]);
    x_start_a+=xa;  // every path starts from the same initial state
  }

  xa=0;
  xa+=x_start_a;
  x=mean;
} 

// Runs f_a() seeded with an adjoint of 1 and packs the result into a single
// gradient vector g of length n+1, where n = dW[0].size():
//   g[0]   = adjoint of the initial state x,
//   g[i+1] = adjoint of p[i].
// x is overwritten by f_a() with the mean over all paths.
//
// If dW holds no paths there is nothing to evaluate: x is left unchanged
// and a zero vector of length p.size()+1 is returned instead of reading
// dW[0].
vector<double> driver(const int ncs, double& x, vector<double>& p,
    const vector<vector<double>>& dW) {
  if (dW.empty()) return vector<double>(p.size()+1,0.0);
  const int n=dW[0].size();
  double xa=1;
  vector<double> pa(n,0);
  f_a(ncs,x,xa,p,pa,dW);
  vector<double> g(n+1,0);
  g[0]=xa;
  for (int i=0;i<n;i++) g[i+1]=pa[i];
  return g;
}  

// Command line: <paths> <steps> <checkpoint_stride>
//   paths              number m of sample paths,
//   steps              number n of steps per path,
//   checkpoint_stride  number ncs of steps per checkpointed section.
// Draws m*n standard-normal samples from a default-seeded engine, calls
// driver() with x0 = 1 and all p[i] = 1, and prints the resulting gradient.
// Returns 1 if the arguments are missing or not positive integers.
int main(int c, char* v[]) {
  // Checked explicitly rather than with assert(): an assertion is compiled
  // out under NDEBUG and the argument reads below would then be invalid.
  if (c!=4) {
    cerr << "usage: " << (c>0 ? v[0] : "program")
         << " <paths> <steps> <checkpoint_stride>" << endl;
    return 1;
  }
  const int m=atoi(v[1]), n=atoi(v[2]), ncs=atoi(v[3]);
  // atoi yields 0 for non-numeric text, so this also rejects garbage input.
  if (m<=0 || n<=0 || ncs<=0) {
    cerr << "error: all three arguments must be positive integers" << endl;
    return 1;
  }
  const double x0=1;
  vector<double> p(n,1); 
  default_random_engine generator;
  normal_distribution<double> distribution(0.0,1.0);
  vector<vector<double>> dW(m,vector<double>(n,1));
  // Fill path by path so the sample sequence stays reproducible.
  for (int i=0;i<m;i++)
    for (int j=0;j<n;j++)
      dW[i][j]=distribution(generator);
  double x=x0;
  vector<double> g=driver(ncs,x,p,dW);
  cout << "dx/dx0=" << g[0] << endl;
  for (int i=0;i<n;i++) 
    cout << "dx/dp[" << i << "]=" << g[i+1] << endl;
  return 0;
}
