function [p,x] = densitySP(model,x,x0,dt,theta,method)

% DENSITYSP  Estimate the transition density for an SDE using the
%    saddlepoint method described in [1].
%
%    Inputs:
%    model: one of {'bessel','sin','blackscholes','cir','ginzland',
%            'ou','log','gcir'} [see README.txt for how to define other
%            models]
%    x: value(s) at the right-hand end of the interval [scalar or vector]
%    x0: value at the left-hand end of the interval [scalar]
%    dt: interval length
%    theta: parameter vector
%    method: one of {'scheme3_u','scheme4_t'} [default: 'scheme4_t'].
%            A method whose name ends in 't' first transforms the data
%            with gammaTransform; any other method works on x directly.
%
%    Outputs:
%    p: approximate transition density at each element of x, returned
%       as a row vector
%    x: the input x, returned unchanged
%
%    Example: 
%    x = linspace(.05,.15,100);
%    f = densitySP('CIR',x,0.1,1/12,[.5 .06 .15],'scheme4_t');
%    plot(x,f);
%
% Reference:
%    [1] Preston, S.P. and Wood, A.T.A. (2012) Approximation of transition
%        densities of stochastic differential equations by saddlepoint 
%        methods applied to small-time Ito-Taylor sample-path expansions,
%        Stat. Comput., 22 (2012), pp. 205-217

% Written by Simon Preston (http://www.maths.nott.ac.uk/~spp), 2009


% use the scheme4_t method by default (also when an empty method is given)
if nargin < 6 || isempty(method), method = 'scheme4_t'; end

% methods whose name ends in 't' work on transformed data; the density on
% the transformed scale is divided by sigmaOfx at the end
if strcmpi(method(end),'t')
   [y,y0,sigmaOfx] = gammaTransform(model,x,x0,theta);
else
   [y,y0,sigmaOfx] = deal(x,x0,1);
end

numPoints = numel(y);
sHat = zeros(1,numPoints);
fHat = zeros(1,numPoints);

c = ItoTaylorCoeffs(model,y0,dt,theta,method);

% the last coefficient shifts the data; the others define the saddlepoint
% equation. Both are the same for every point, so slice them once.
cSaddle = c(1:end-1);
cShift = c(end);

for iPoint = 1:numPoints
   z = y(iPoint) - y0 - cShift;
   sHat(iPoint) = solveSaddle(z,cSaddle,dt,method);
   fHat(iPoint) = calc_fHat(z,cSaddle,dt,sHat(iPoint),method);
end

% sigmaOfx is either a scalar or has one entry per point. Reshape it to a
% row so the division stays elementwise even if it comes back as a column
% (a row ./ column would otherwise expand to a numPoints-by-numPoints matrix).
p = fHat./reshape(sigmaOfx,1,[]);


% -------------------------------------------------------------------------

function out = solveSaddle(x,c,dt,method)

% SOLVESADDLE  Solve the saddlepoint equation K(1,c,dt,s) = x for s.
%    x: shifted data value (right-hand side of the SP equation)
%    c: saddlepoint coefficients (the Ito-Taylor coefficients without the
%       final shift term)
%    dt: interval length
%    method: 'scheme3_u' or 'scheme4_t' (case-insensitive)
%    out: the saddlepoint sHat (a real scalar), or NaN if none was found

method = lower(method);

% the method-specific function K is looked up by name; lowercasing keeps the
% lookup consistent with the case-insensitive switch below
K = str2func(['K', method]);

switch method
    
   case 'scheme3_u'
      % the SP equation is a 3rd-order polynomial whose roots can be found using
      % matlab's function 'roots' 
      p = [4 * c(3) ^ 2 * dt ^ 5 * c(2) ^ 2,...
         -7 * c(3) ^ 2 * dt ^ 4 * c(2) - 12 * c(1) ^ 2 * c(2) * dt ^ 2 - 12 * ...
         c(1) * c(3) * dt ^ 3 * c(2) - 48 * x * c(2) ^ 2 * dt ^ 2,...
         -24 * c(2) ^ 2 * dt ^ 2 + 12 * c(1) ^ 2 * dt + 12 * c(1) * c(3) * ...
         dt ^ 2 + 4 * c(3) ^ 2 * dt ^ 3 + 48 * x * c(2) * dt,...
         12 * c(2) * dt - 12 * x];
      r = roots(p);
      if c(2)~=0, crit = 1/(2*c(2)*dt); else crit = NaN; end
      interv = findIntervalWithSignChange(K,x,c,dt,crit);
      % keep only (numerically) real roots: relational operators compare
      % only the real part of complex values, so complex roots would
      % otherwise pass the interval test below
      isReal = abs(imag(r)) <= 1e-8*max(1,abs(r));
      r = real(r(isReal));
      r = r(r>interv(1) & r<interv(2));
      if isempty(r)
         % return NaN rather than [] -- assigning [] to sHat(iPoint) in the
         % caller would delete that element and misalign the results
         out = NaN;
      else
         % if several roots qualify, take the one that best satisfies the
         % SP equation so that a single scalar is always returned
         resid = arrayfun(@(s) abs(K(1,c,dt,s) - x), r);
         [~,iBest] = min(resid);
         out = r(iBest);
      end
      
   case 'scheme4_t'
      if c(4)~=0, crit = pi^2/(8*c(4)*dt^2); else crit = NaN; end
      interv = findIntervalWithSignChange(K,x,c,dt,crit);
      %out = fzero(@(s) K(1,c,dt,s) - x, interv, ...
      %   optimset('Display','notify','TolX',eps));
      try
         out = fzero(@(s) K(1,c,dt,s) - x, interv);
      catch  %#ok<*CTCH>
         % root-finding failure is reported as NaN, as in the original
         out = NaN;
      end
      
   otherwise
      error('densitySP:unknownMethod', ...
         'Unknown saddlepoint method ''%s''.', method);
      
end


% -------------------------------------------------------------------------

function out = calc_fHat(x,c,dt,sHat,method)

% CALC_FHAT  Evaluate the saddlepoint density approximation at sHat:
%    fHat = (2*pi*K(2,c,dt,sHat))^(-1/2) * exp(K(0,c,dt,sHat) - sHat*x)
%    where K is the method-specific function looked up by name, with its
%    first argument selecting which quantity is evaluated.
%    x: shifted data value
%    c: saddlepoint coefficients
%    dt: interval length
%    sHat: saddlepoint (NaN propagates to a NaN result)
%    method: 'scheme3_u' or 'scheme4_t' (case-insensitive)

% lowercase the name so the lookup matches regardless of the caller's casing
K = str2func(['K', lower(method)]);

out = (2*pi*K(2,c,dt,sHat)).^(-1/2) .* ...
         exp(K(0,c,dt,sHat) - sHat.*x);

