#include <cerrno>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <stdexcept>

#include "std_includes.h"

// Adjoint (reverse mode) of a Monte Carlo path simulation.
//
// Primal: each of the m = dW.size() paths starts at the input x and takes
// n = dW[0].size() explicit time steps of size dt = 1/n,
//   x <- x + dt*p[i]*sin(x*t) + p[i]*cos(x*t)*sqrt(dt)*dW[j][i],
// with t = 0, dt, 2*dt, ... accumulated by repeated addition. The primal
// result is the mean of the m end points.
//
// Arguments (adjoint-mode in/out convention):
//   x   in: initial state;                out: the primal result (path mean).
//   xa  in: adjoint of the result (seed); out: adjoint of the input x.
//   p   per-step coefficients; needs at least n entries.
//   pa  incremented (not overwritten) by the adjoints of p; >= n entries.
//   dW  m rows with at least n entries each. The sqrt(dt) scaling suggests
//       standard-normal samples (assumption - the caller decides).
// Throws std::invalid_argument, before modifying any argument, if dW is empty
// or if p, pa or any row of dW is shorter than n.
// Memory: the tape holds m*n values of T.
template<typename T>
void f_a(T& x, T& xa, const std::vector<T>& p, std::vector<T>& pa,
    const std::vector<std::vector<double>>& dW) {
  // std:: overloads for built-in T; ADL still finds overloads for AD types.
  using std::sin;
  using std::cos;
  using std::sqrt;

  if (dW.empty())
    throw std::invalid_argument("f_a: dW must contain at least one path");
  const std::size_t m = dW.size();
  const std::size_t n = dW[0].size();
  for (const auto& row : dW)
    if (row.size() < n)
      throw std::invalid_argument("f_a: every row of dW needs at least dW[0].size() entries");
  if (p.size() < n || pa.size() < n)
    throw std::invalid_argument("f_a: p and pa need at least dW[0].size() entries");

  const double dt = 1. / n;
  const double sqrt_dt = sqrt(dt);

  // Every path restarts at t = 0 and advances by dt, so the time grid is the
  // same for all paths: compute it once (same additions, same bits) instead of
  // taping t m*n times.
  std::vector<double> ts(n);
  {
    double t = 0;
    for (std::size_t i = 0; i < n; ++i) {
      ts[i] = t;
      t += dt;
    }
  }

  // Augmented primal: simulate every path from x0, recording the state before
  // each step at tape[j*n + i] for the reverse sweep.
  std::vector<T> tape;
  tape.reserve(m * n);
  T s = 0;
  const T x0 = x;
  for (std::size_t j = 0; j < m; ++j) {
    for (std::size_t i = 0; i < n; ++i) {
      const double t = ts[i];
      tape.push_back(x);
      x += dt*p[i]*sin(x*t) + p[i]*cos(x*t)*sqrt_dt*dW[j][i];
    }
    s += x;
    x = x0;  // next path starts from the initial state again
  }
  x = s / static_cast<double>(m);
  const T y = x;  // primal result, restored after the reverse sweep

  // Adjoint sweep: statements of the primal in reverse order.
  T sa = 0, x0a = 0;
  sa += xa / static_cast<double>(m);  // adjoint of x = s/m
  xa = 0;
  for (std::size_t j = m; j-- > 0;) {
    x0a += xa;  // adjoint of x = x0 at the end of path j
    xa = 0;
    xa += sa;   // adjoint of s += x
    for (std::size_t i = n; i-- > 0;) {
      const double t = ts[i];
      const T& xi = tape[j*n + i];  // state before step i of path j
      const T sn = sin(xi*t);
      const T cs = cos(xi*t);
      // d(step)/d(p[i]) and d(step)/d(x), applied to the incoming adjoint.
      pa[i] += (dt*sn + cs*sqrt_dt*dW[j][i]) * xa;
      xa = (1 + dt*p[i]*t*cs - p[i]*t*sn*sqrt_dt*dW[j][i]) * xa;
    }
  }
  xa += x0a;  // adjoint of x0 = x: the input receives all paths' contributions
  x = y;
}

vector<double> driver(double& x, vector<double>& p,
    const vector<vector<double>>& dW) {
  int n=dW[0].size();
  vector<double> g(n+1,0);
  double xa=1; vector<double> pa(n,0);
  f_a(x,xa,p,pa,dW);
  g[0]=xa;
  for (int i=0;i<n;i++) g[i+1]=pa[i];
  return g;
}  

int main(int c, char* v[]) {
  assert(c==3); int m=atoi(v[1]), n=atoi(v[2]);
  const double x0=1;
  vector<double> p(n,1); 
  default_random_engine generator;
  normal_distribution<double> distribution(0.0,1.0);
  vector<vector<double>> dW(m,vector<double>(n,1));
  for (int i=0;i<m;i++)
    for (int j=0;j<n;j++)
      dW[i][j]=distribution(generator);
  double x=x0;
  vector<double> g=driver(x,p,dW);
  cout << "dx/dx0=" << g[0] << endl;
  for (int i=0;i<n;i++) 
    cout << "dx/dp[" << i << "]=" << g[i+1] << endl;
  return 0;
}
