#include "std_includes.h"

// Selects which half of a reverse-mode step step_a() performs:
// AUGMENTED_PRIMAL advances the state and records what the adjoint needs,
// SPLIT_ADJOINT restores that record and propagates adjoints backwards.
enum Mode { AUGMENTED_PRIMAL, SPLIT_ADJOINT };

// One time step of the stencil update
//   y[i] += v*(left - 2*y[i] + right)/m,   v = p[0]*(n+1)^2,  n = y.size(),
// where left/right are the neighbouring entries of y, with p[1] standing in
// to the left of y[0] and p[2] to the right of y[n-1].
//
// mode == AUGMENTED_PRIMAL: pushes the incoming y onto an internal stack and
//   then overwrites y with the updated state. p_a and y_a are not touched.
// mode == SPLIT_ADJOINT: pops the most recently pushed state back into y and
//   updates the adjoints in place: y_a becomes the adjoint of the step's
//   input state, and the contributions to p[0..2] are added to p_a[0..2].
//
// m    number of steps per unit interval (the step size is 1/m).
// p    parameters, at least 3 entries (read only).
// p_a  adjoints of p, at least 3 entries (accumulated into, adjoint mode).
// y    state vector (in/out).
// y_a  adjoints of y, same length as y (in/out, adjoint mode).
//
// The stack is a function-local static, i.e. shared by every call of the
// same AT instantiation; each adjoint call must be matched by an earlier
// primal call (last in, first out).
template <typename AT, typename PT>
inline void step_a(Mode mode, const int m,
  const std::vector<PT>& p, std::vector<PT>& p_a,
  std::vector<AT>& y, std::vector<AT>& y_a)
{
  static std::stack<std::vector<AT>> tbr;
  assert(p.size()>=3);
  switch (mode) {
  case AUGMENTED_PRIMAL: {
    const int n=static_cast<int>(y.size());
    // Computed in double: (n+1)*(n+1) overflows int once n >= 46340.
    const double ns=static_cast<double>(n+1)*(n+1);
    AT v=p[0]*ns;
    std::vector<AT> r(n);
    if (n==1) {
      // A single unknown sees both boundary values.
      r[0]=v*(p[1]-2*y[0]+p[2]);
    } else if (n>1) {
      r[0]=v*(p[1]-2*y[0]+y[1]);
      for (int i=1;i<n-1;i++)
        r[i]=v*(y[i-1]-2*y[i]+y[i+1]);
      r[n-1]=v*(y[n-2]-2*y[n-1]+p[2]);
    }
    // Record the pre-update state; the adjoint needs it for the p[0] term.
    tbr.push(y);
    for (int i=0;i<n;i++) y[i]+=r[i]/m;
    break;
  }
  case SPLIT_ADJOINT: {
    assert(!tbr.empty());
    // Take the recorded state back without copying it.
    y.swap(tbr.top()); tbr.pop();
    const int n=static_cast<int>(y.size());
    assert(y_a.size()==y.size() && p_a.size()>=3);
    const double ns=static_cast<double>(n+1)*(n+1);
    AT v=p[0]*ns;
    // Adjoint of y[i]+=r[i]/m. It has to be captured for all i up front,
    // because the stencil adjoints below modify neighbouring y_a entries.
    std::vector<AT> r_a(n,0);
    for (int i=0;i<n;i++) r_a[i]+=y_a[i]/m;
    if (n==1) {
      p_a[0]+=ns*(p[1]-2*y[0]+p[2])*r_a[0];
      p_a[1]+=v*r_a[0]; y_a[0]-=v*2*r_a[0];
      p_a[2]+=v*r_a[0];
    } else if (n>1) {
      // Left boundary row.
      p_a[0]+=ns*(p[1]-2*y[0]+y[1])*r_a[0];
      p_a[1]+=v*r_a[0]; y_a[0]-=v*2*r_a[0];
      y_a[1]+=v*r_a[0];
      // Interior rows.
      for (int i=1;i<n-1;i++) {
        p_a[0]+=ns*(y[i-1]-2*y[i]+y[i+1])*r_a[i];
        y_a[i-1]+=v*r_a[i]; y_a[i]-=v*2*r_a[i];
        y_a[i+1]+=v*r_a[i];
      }
      // Right boundary row.
      p_a[0]+=ns*(y[n-2]-2*y[n-1]+p[2])*r_a[n-1];
      y_a[n-2]+=v*r_a[n-1]; y_a[n-1]-=v*2*r_a[n-1];
      p_a[2]+=v*r_a[n-1];
    }
    break;
  }
  }
}

template <typename AT, typename PT>
inline void f_a(const int m, 
  const vector<PT>& p, vector<PT>& p_a, 
  vector<AT>& y, vector<AT>& y_a) 
{
  for (int j=0;j<m;j++) step_a(AUGMENTED_PRIMAL,m,p,p_a,y,y_a);
  for (int j=0;j<m;j++) step_a(SPLIT_ADJOINT,m,p,p_a,y,y_a);
}

// Driver: reads the state size n and the number of time steps m from the
// command line, seeds the adjoint of the middle state entry with 1, runs the
// adjoint sweep f_a and prints the resulting y_a, i.e. the derivative of the
// final y[n/2] with respect to each initial y[i].
// Returns 1 (with a message on stderr) on bad arguments, 0 otherwise.
int main(int c, char* v[]) {
  // Explicit check instead of assert: it must survive NDEBUG builds, since
  // v[1] and v[2] are dereferenced right below.
  if (c!=3) {
    std::cerr << "usage: " << (c>0 ? v[0] : "prog") << " <n> <m>" << '\n';
    return 1;
  }
  const int n=atoi(v[1]), m=atoi(v[2]);
  // n sizes the vectors and n/2 indexes the seed below, so it must be >= 1.
  if (n<1) {
    std::cerr << "error: n must be a positive integer" << '\n';
    return 1;
  }
  vector<double> y(n), y_a(n,0);
  for (int i=0;i<n;i++) y[i]=(i+1)*log(static_cast<double>(i+2));
  vector<double> p={1e-3,42,0}, p_a(3,0);
  // Seed: differentiate the middle entry of the final state.
  y_a[n/2]=1;
  f_a(m,p,p_a,y,y_a);
  // '\n' rather than endl: no need to flush after every line.
  for (int i=0;i<n;i++) 
    cout << "dy(n/2)/dy0[" << i << "]=" << y_a[i] << '\n';
  return 0;
}
