# This Python script simulates the three-period student loan model from "The Self Perpetuating Student Loan Crisis"

import math
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
import copy
import numpy as np
import pandas as pd #get 'er done
import random
import statsmodels.formula.api as sm 
from scipy import optimize

class ModelFundamentals(object):
    """Three-period student loan model with an endogenous debt-cancellation probability.

    Holds the model parameters and closed-form expressions for the borrowing
    balances at the start of periods 2 and 3, the logistic probability of debt
    cancellation, and the map ``p -> cancel_prob(average balance given p)``
    whose fixed point is an equilibrium cancellation probability.

    Parameters
    ----------
    h : float
        College wage premium; enters the budget discounted once and twice
        (periods 2 and 3).
    T : float
        Tuition; enters the budget undiscounted (period 1).
    mybeta : float
        Discount factor. Stored, but not used by any method of this class.
    r : float
        Net interest rate; ``1 + r`` is the gross rate used for discounting.
    cancel_param1 : float
        Midpoint of the logistic cancellation curve.
    cancel_param2 : float
        Steepness of the logistic cancellation curve.
    """

    def __init__(self, h, T, mybeta, r, cancel_param1, cancel_param2):
        self.h = h
        self.T = T
        self.mybeta = mybeta
        self.r = r
        self.gross_r = float(1) + r
        self.cancel_param1 = cancel_param1
        self.cancel_param2 = cancel_param2
        self.choose_college()  # check for parameter violation

    def _discount_factors(self):
        """Return ``(1/R, 1/R**2)`` for the gross interest rate ``R = 1 + r``."""
        d1 = 1 / self.gross_r
        return d1, d1 * d1

    def choose_college(self):
        """Warn when the parameters imply individuals would not attend college.

        Prints a warning if the consumption level from ``cons()`` is below 1.
        NOTE(review): the threshold 1 is presumably the consumption of someone
        who skips college; confirm against the paper. Always returns 0.
        """
        if self.cons() < 1:
            print('Warning! individuals do not choose college!')
        return 0

    def balance2(self, p):
        """Return the borrowing balance at the beginning of period 2.

        ``p`` is the debt-cancellation probability (coerced with ``float``).
        The balance is the ``p = 0`` balance plus a term proportional to
        ``cons()`` that increases with ``p``.
        """
        d1, d2 = self._discount_factors()
        annuity = 1 + d1 + d2  # present value of 1 in each of periods 1-3
        later = d1 + d2        # present value of 1 in each of periods 2-3
        p = float(p)

        bal_orig = later * (self.h + self.T) / annuity
        bal_pterm = self.cons() * (p * later) / (annuity - p * later)
        return bal_orig + bal_pterm

    def balance3(self, p):
        """Return the borrowing balance at the beginning of period 3.

        ``p`` is the debt-cancellation probability (coerced with ``float``).
        The balance is the ``p = 0`` balance plus a term proportional to
        ``cons()`` that increases with ``p``. No cap relative to the period-2
        balance is applied.
        """
        d1, d2 = self._discount_factors()
        annuity = 1 + d1 + d2
        p = float(p)

        bal_orig = d1 * (self.h + self.T) / annuity
        bal_pterm = self.cons() * (p * d2) / (d1 + d2 - p * d2)
        return bal_orig + bal_pterm

    def cancel_prob(self, b):
        """Return the probability that a debt balance ``b`` is cancelled.

        Logistic in ``b`` with midpoint ``cancel_param1`` and steepness
        ``cancel_param2``. It is evaluated in a numerically stable form so that
        balances far from the midpoint give 0.0 or 1.0 instead of raising
        ``OverflowError`` in ``math.exp``.
        """
        z = self.cancel_param2 * (b - self.cancel_param1)
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        # For z < 0, exp(-z) can overflow; use the algebraically equal form.
        ez = math.exp(z)
        return ez / (1.0 + ez)

    def eq_func(self, p):
        """Map a conjectured cancellation probability to the induced one.

        Averages the period-2 and period-3 balances implied by ``p`` with equal
        weights and returns the cancellation probability for that average.
        A fixed point ``p == eq_func(p)`` is an equilibrium.
        """
        bbar = 0.5 * self.balance2(p) + 0.5 * self.balance3(p)
        return self.cancel_prob(bbar)

    def cons(self):
        """Return the per-period consumption level ``cbar`` (a scalar).

        ``cbar = ((1/R + 1/R**2) * h - T) / (1 + 1/R + 1/R**2)``. This is the
        constant level ``c`` for which ``c * (1 + 1/R + 1/R**2)`` equals the
        present value of ``h`` in periods 2-3 net of ``T`` in period 1.
        """
        d1, d2 = self._discount_factors()
        return ((d1 + d2) * self.h - self.T) / (1 + d1 + d2)


if __name__ == '__main__':
    # Main entry: solve for the equilibrium debt-cancellation probability,
    # report it against the no-cancellation baseline, and plot the equilibrium map.

    # seed rngs (nothing below draws random numbers; kept for reproducibility)
    random.seed(80085)
    np.random.seed(80085)

    # enter parameters here
    h = 1.85  # college wage premium
    T = 0.12  # tuition
    r = 0.256  # net interest rate (the model uses 1 + r as the gross rate)
    mybeta = 1 / (float(1) + r)  # discount factor
    cancel_param1 = 1.345  # midpoint of logistic curve (current 1.3) (old 10)
    cancel_param2 = 3.5  # curvature of logistic curve (current 4.0) (old 1)

    # build the model; the constructor warns if the parameters rule out college
    model = ModelFundamentals(h, T, mybeta, r, cancel_param1, cancel_param2)

    # solve for equilibrium fixed point bailout probability
    sol = optimize.fixed_point(model.eq_func, 0.9)

    # equal-weight average of period-2 and period-3 balances, at the solution
    # and with no anticipated cancellation (p = 0)
    sol_avg_balance = 0.5 * model.balance3(sol) + 0.5 * model.balance2(sol)
    orig_avg_balance = 0.5 * model.balance3(0) + 0.5 * model.balance2(0)

    # single-argument print(...) behaves identically on Python 2 and 3
    print("The solution prob of debt cancellation is " + str(sol))
    print("The original prob of debt cancellation is " + str(model.cancel_prob(orig_avg_balance)))
    print("The solution average debt balance is " + str(sol_avg_balance))
    print("The original average debt balance is " + str(orig_avg_balance))

    # plot the equilibrium map against the 45-degree line; crossings are fixed points
    xvals = np.linspace(0.001, 0.999, 100)
    yvals = [model.eq_func(x) for x in xvals]
    plt.plot(xvals, yvals, color='blue')
    plt.plot(xvals, xvals, color='k', linestyle='dashed')
    plt.xlabel("initial debt cancellation prob")
    plt.ylabel("induced debt cancellation prob")
    plt.xlim(0, 1)
    plt.savefig("equilibrium_targeted.png")

