#include "std_includes.h"
#include "f.h"

enum Mode { AUGMENTED_PRIMAL, SPLIT_ADJOINT };

// One step of the adjoint of an implicitly defined step y -> y'.
//
// AUGMENTED_PRIMAL: calls newton(m,p,y_prev,y) with y_prev a copy of the
//   incoming y; y is updated in place. The resulting y is pushed onto an
//   internal stack (the tape). y_a is not touched.
// SPLIT_ADJOINT: restores y from the top of the tape (and pops it), fills a
//   zero-initialised n x n matrix A via dfdy(m,p,y,A) at that state, and
//   overwrites y_a with the solution of A^T * x = y_a.
//   NOTE(review): newton/dfdy live in f.h; presumably A is the Jacobian of
//   the step's residual w.r.t. y -- confirm there.
//
// The tape is a function-local static: it persists across calls and is
// shared by every call with the same (T,N). Each AUGMENTED_PRIMAL call must
// be matched by exactly one later SPLIT_ADJOINT call, in reverse order.
template <typename T, int N=Dynamic>
inline void step_a(Mode mode, const int m, const Matrix<T,3,1>& p, 
                   Matrix<T,N,1>& y, Matrix<T,N,1>& y_a) {
  static stack<Matrix<T,N,1>> psols;
  switch (mode) {
  case AUGMENTED_PRIMAL: {
    // newton gets the pre-step state separately from y, which it updates.
    Matrix<T,N,1> y_prev=y;
    newton(m,p,y_prev,y); 
    psols.push(y);
    break;
  }
  case SPLIT_ADJOINT: {
    // Popping an empty std::stack is undefined behaviour; catch unbalanced use.
    assert(!psols.empty() && "SPLIT_ADJOINT without matching AUGMENTED_PRIMAL");
    y=psols.top(); psols.pop();
    // A is only needed here, so it is no longer built on every primal step.
    // It stays zero-initialised in case dfdy writes only some entries.
    const auto n=y.size();
    Matrix<T,N,N> A=Matrix<T,N,N>::Zero(n,n);
    dfdy(m,p,y,A);
    PartialPivLU<Matrix<T,N,N>> LU(A.transpose());
    y_a=LU.solve(y_a); 
    break;
  }
  }  
}

// Adjoint driver over m steps.
// Forward sweep: m AUGMENTED_PRIMAL calls to step_a, each recording its
// resulting state on step_a's tape.
// Reverse sweep: m SPLIT_ADJOINT calls, which consume those states in reverse
// order and update y_a each time.
// For m >= 1 with a balanced tape, y on return holds the state recorded by the
// first primal step (the last one popped), not the caller's original y.
// For m <= 0 nothing is done.
template <typename T, int N=Dynamic>
inline void f_a(const int m, const Matrix<T,3,1>& p, Matrix<T,N,1>& y, 
                Matrix<T,N,1>& y_a) {
  // Forward (recording) sweep.
  for (int remaining=m; remaining>0; --remaining) {
    step_a(AUGMENTED_PRIMAL,m,p,y,y_a);
  }
  // Reverse (adjoint) sweep.
  for (int remaining=m; remaining>0; --remaining) {
    step_a(SPLIT_ADJOINT,m,p,y,y_a);
  }
}

// Usage: <program> n m
//   n : dimension of the state y (must be >= 1)
//   m : number of steps f_a performs (must be >= 0)
// Builds an initial state y, seeds the adjoint with the unit vector at index
// n/2, runs f_a, and prints every component of the resulting y_a.
int main(int c, char* v[]){
  // Explicit check instead of assert(c==3): assert vanishes under NDEBUG and
  // v[1]/v[2] would then be read out of bounds.
  if (c!=3) {
    std::cerr << "usage: <program> n m" << '\n';
    return 1;
  }
  const int n=atoi(v[1]), m=atoi(v[2]);
  // n >= 1 keeps y(n) a valid size and y_a(n/2) in range.
  // m counts the steps f_a performs, so it cannot be negative.
  if (n<1 || m<0) {
    std::cerr << "error: require n >= 1 and m >= 0" << '\n';
    return 1;
  }
  Matrix<double,Dynamic,1> y(n); 
  for (int i=0;i<n;i++) y(i)=(i+1)*log(static_cast<double>(i+2));
  Matrix<double,3,1> p; p(0)=1e-3; p(1)=42; p(2)=0;
  Matrix<double,Dynamic,1> y_a=Matrix<double,Dynamic,1>::Zero(n); 
  y_a(n/2)=1; 
  f_a(m,p,y,y_a);  
  // '\n' instead of endl: no flush per line; stdout is flushed at exit.
  for(int i=0;i<n;i++)
    cout << "dy(n/2)/dy0[" << i << "]=" << y_a(i) << '\n';
  return 0;
}


