#include <stdexcept>

#include "std_includes.h"

/// Selects what `path` computes.
enum Mode {
  PRIMAL,                     ///< forward simulation only
  CONTEXT_FREE_JOINT_ADJOINT  ///< forward simulation followed by the adjoint sweep
};

/// Simulates one sample path of
///   dx = p(t) sin(x t) dt + p(t) cos(x t) dW(t),   t in [0, 1],
/// with n explicit Euler-Maruyama-style steps of size dt = 1/n:
///   x_{i+1} = x_i + dt p[i] sin(x_i t_i) + p[i] cos(x_i t_i) sqrt(dt) dW_j[i],
///   t_i = i dt.
/// dW_j[i] is scaled by sqrt(dt) here, so it is presumably a standard-normal
/// sample (main draws it from N(0, 1)).
///
/// PRIMAL: x is overwritten with x_n; xa and pa are not touched.
/// CONTEXT_FREE_JOINT_ADJOINT: the forward sweep records every x_i, then the
///   reverse sweep propagates adjoints. On entry xa is the adjoint of x_n; on
///   return x is restored to x_0, xa holds the adjoint of x_0, and pa[i] has
///   been incremented (not overwritten) by the adjoint of p[i].
///
/// @param mode  what to compute (see above)
/// @param n     number of time steps; n <= 0 is a no-op
/// @param x     in: x_0; out: x_n (PRIMAL) or x_0 (adjoint mode)
/// @param xa    adjoint of the path output in, adjoint of x_0 out (adjoint mode)
/// @param p     per-step coefficients, at least n entries
/// @param pa    adjoint accumulator for p, at least n entries (adjoint mode)
/// @param dW_j  per-step increments of this path, at least n entries
void path(Mode mode, const int n,
  double& x, double& xa,
  const vector<double>& p, vector<double>& pa,
  const vector<double>& dW_j) {
  if (n <= 0) return;  // empty grid: nothing to simulate or differentiate
  const size_t steps = static_cast<size_t>(n);
  assert(p.size() >= steps && dW_j.size() >= steps);
  const double dt = 1.0 / n;
  const double sqrt_dt = sqrt(dt);  // loop invariant
  // t_i is computed from i instead of being accumulated with t += dt (and,
  // in the reverse sweep, reconstructed with t -= dt starting from 1). Both
  // sweeps therefore see bit-identical t_i, so the adjoint differentiates
  // exactly the steps the forward sweep took.
  auto time_at = [dt](int i) { return i * dt; };
  // One step x_i -> x_{i+1}.
  auto step = [&](int i, double xi) {
    const double t = time_at(i);
    return xi + (dt * p[i] * sin(xi * t)
                 + p[i] * cos(xi * t) * sqrt_dt * dW_j[i]);
  };
  switch (mode) {
    case PRIMAL:
      for (int i = 0; i < n; i++) x = step(i, x);
      break;
    case CONTEXT_FREE_JOINT_ADJOINT: {
      assert(pa.size() >= steps);
      // augmented primal: tape[i] = x_i, the state entering step i
      vector<double> tape(steps);
      for (int i = 0; i < n; i++) {
        tape[i] = x;
        x = step(i, x);
      }
      // adjoint: walk the steps backwards, restoring x_i from the tape
      for (int i = n - 1; i >= 0; i--) {
        x = tape[i];
        const double t = time_at(i);
        const double s = sin(x * t), c = cos(x * t);
        // d x_{i+1} / d p[i]
        pa[i] += (dt * s + c * sqrt_dt * dW_j[i]) * xa;
        // d x_{i+1} / d x_i
        xa *= 1 + dt * p[i] * t * c - p[i] * t * s * sqrt_dt * dW_j[i];
      }
      break;
    }
  }
}

/// Sample mean of the end values of m simulated paths, together with its
/// adjoint, in one call.
///
/// Forward: x is replaced by the mean of x_n over the m paths driven by the
/// rows of dW, all started from the incoming x.
/// Reverse: xa enters as the adjoint of that mean and leaves as the adjoint
/// of the starting value; adjoints of p are accumulated into pa (pa is
/// incremented, not overwritten).
///
/// Each path is reversed jointly (CONTEXT_FREE_JOINT_ADJOINT): only the
/// starting value x0 is kept from the forward sweep, and every path is
/// re-simulated inside the reverse sweep.
///
/// @param x   in: starting value x0; out: sample mean of the path end values
/// @param xa  in: adjoint of the output; out: adjoint of x0
/// @param p   per-step coefficients; at least dW[0].size() entries
/// @param pa  adjoint accumulator for p; at least dW[0].size() entries
/// @param dW  one row of increments per sample path; all rows equally long
/// @throws std::invalid_argument if dW is empty, its rows differ in length,
///         or p / pa are shorter than a row
void f_a(double& x, double& xa,
    const vector<double>& p, vector<double>& pa,
    const vector<vector<double>>& dW) {
  // Without these checks, dW[0] and s/m are undefined for empty dW, and
  // path() would index out of bounds for short rows / short p or pa.
  if (dW.empty())
    throw std::invalid_argument("f_a: dW must contain at least one path");
  const size_t steps = dW[0].size();
  for (const auto& row : dW)
    if (row.size() != steps)
      throw std::invalid_argument("f_a: all rows of dW must have the same length");
  if (p.size() < steps || pa.size() < steps)
    throw std::invalid_argument("f_a: p and pa need one entry per time step");
  const int m = static_cast<int>(dW.size());
  const int n = static_cast<int>(steps);

  // augmented primal: only x0 needs to survive for the reverse sweep
  const double x0 = x;
  double s = 0;
  for (int j = 0; j < m; j++) {
    x = x0;
    path(PRIMAL, n, x, xa, p, pa, dW[j]);  // xa, pa untouched in PRIMAL mode
    s += x;
  }
  x = s / m;
  const double y = x;  // result; x is reused as scratch by the reverse sweep

  // adjoint of x = s / m
  const double sa = xa / m;
  xa = 0;
  // adjoint of every path, re-simulated from x0 (joint reversal)
  double x0a = 0;
  for (int j = m - 1; j >= 0; j--) {
    x = x0;
    xa = sa;  // adjoint of s += x
    path(CONTEXT_FREE_JOINT_ADJOINT, n, x, xa, p, pa, dW[j]);
    x0a += xa;  // every path starts from the same x0
  }
  xa = x0a;
  x = y;
}

/// Gradient of the sample mean computed by f_a with respect to the starting
/// value and every per-step coefficient.
///
/// @param x   in: starting value x0; out: the sample mean (as set by f_a)
/// @param p   per-step coefficients (read only)
/// @param dW  increments, one row per sample path
/// @return    {d mean/d x0, d mean/d p[0], ..., d mean/d p[n-1]} with
///            n = dW[0].size()
/// @throws std::invalid_argument if dW is empty (further checks in f_a)
vector<double> driver(double& x, const vector<double>& p,
    const vector<vector<double>>& dW) {
  if (dW.empty())  // dW[0] below would be undefined behaviour
    throw std::invalid_argument("driver: dW must contain at least one path");
  const size_t n = dW[0].size();
  double xa = 1;            // seed: adjoint of the output mean
  vector<double> pa(n, 0);  // f_a accumulates into pa, so start from zero
  f_a(x, xa, p, pa, dW);
  vector<double> g;
  g.reserve(n + 1);
  g.push_back(xa);
  g.insert(g.end(), pa.begin(), pa.end());
  return g;
}

/// Command-line entry point: `<prog> m n` simulates m sample paths of n steps
/// each, with x0 = 1 and p[i] = 1, and prints the gradient of the sample mean
/// with respect to x0 and every p[i].
int main(int c, char* v[]) {
  // A plain assert would vanish under NDEBUG and leave v[1], v[2] out of
  // bounds, so validate the command line explicitly.
  if (c != 3) {
    cerr << "usage: " << (c > 0 ? v[0] : "path") << " <paths m> <steps n>\n";
    return 1;
  }
  const int m = atoi(v[1]), n = atoi(v[2]);
  // atoi yields 0 for non-numeric text; m or n <= 0 would leave dW empty.
  if (m <= 0 || n <= 0) {
    cerr << "m and n must be positive integers\n";
    return 1;
  }
  const double x0 = 1;
  vector<double> p(n, 1);
  // Default-constructed engine: same sequence on every run with a given
  // standard library, so results are reproducible.
  default_random_engine generator;
  normal_distribution<double> distribution(0.0, 1.0);
  vector<vector<double>> dW(m, vector<double>(n));
  for (auto& row : dW)
    for (auto& w : row) w = distribution(generator);
  double x = x0;
  const vector<double> g = driver(x, p, dW);
  cout << "dx/dx0=" << g[0] << '\n';
  for (int i = 0; i < n; i++)
    cout << "dx/dp[" << i << "]=" << g[i + 1] << '\n';
  return 0;
}
