"""
Linear regression
^^^^^^^^^^^^^^^^^

This example begins with linear regression in the real domain and then builds up
to show how linear problems can be thought of in the complex domain.

The Wikipedia entry is a useful starting point for those unfamiliar with linear
regression.

- https://en.wikipedia.org/wiki/Linear_regression

Useful references:

- https://stats.stackexchange.com/questions/66088/analysis-with-complex-data-anything-different
- https://www.chrishenson.net/article/complex_regression


Begin the example by importing packages/functions and switching off the logger
to avoid superfluous messages.
"""
# sphinx_gallery_thumbnail_number = 5
from loguru import logger
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from regressioninc.linear.models import add_intercept, OLS
from regressioninc.testing.complex import ComplexGrid

# Switch off the loguru logger so that superfluous log messages do not clutter
# the example output.
logger.remove()

# %%
# One of the most straightforward linear problems to understand is the equation
# of a line. Let's look at a line with gradient 3 and intercept -2.
params = np.array([3])  # gradient m of the line
intercept = -2  # constant term c of the line
# Integers -5..4 as a single column: one regressor, ten observations
X = np.arange(-5, 5)[:, np.newaxis]
y = X @ params + intercept

# %%
# The above code is essentially doing the following:
#
# .. math::
#
#   y = mx + c
#
# where :math:`m = 3` and :math:`c = -2` with the regressor going from -5 to 4
# in steps of 1.
#
# Let's have a quick look at the values, noting that there is only a single
# regressor, which is in the first column (index 0) of X. The variable name X
# represents the values of one or more regressors and is usually referred to as
# the plural regressors.
print(f"Regressors X {X[:, 0]}")  # the single regressor is column 0 of X
print(f"Regressand y {y}")

# %%
# And now plot the relationship between the regressors X and regressand y.
fig = plt.figure()
# Scatter the single regressor (column 0 of X) against the regressand
plt.scatter(X[:, 0], y)
plt.xlabel("Independent variable")
plt.ylabel("Dependent variable")
plt.tight_layout()
# Use pyplot's show, as the other figures in this example do. Figure.show
# does not manage a GUI event loop, so the figure may not stay visible when
# this file is run as a plain script.
plt.show()

# %%
# When performing linear regression, the aim is to:
#
# - calculate the parameters (also called coefficients)
# - given the regressors X (values of the independent variable)
# - and values of the regressand y (values of the dependent variable)
#
# This can be done with linear regression, and the most common method of linear
# regression is least squares, which aims to estimate the parameters whilst
# minimising the squared misfit between the regressands and predicted
# regressands calculated using the estimated parameters.
#
# Before solving, let's add a column of 1s to our regressors X. This is to
# make sure a constant intercept is also solved for.
X = add_intercept(X)  # extra regressor so a constant intercept is solved for
print(X.T)  # transposed, so each column of X is printed as a row

# %%
# Now use ordinary least squares to estimate the parameters.
model = OLS()
model.fit(X, y)
# Name the fitted parameters before printing them
estimated_params = model.estimate.params
print(estimated_params)

# %%
# Least squares was able to correctly calculate the slope and intercept for the
# real-valued regression problem. Let's look at the predicted regressands using
# the estimated parameters.
preds = model.predict(X)  # regressands predicted with the estimated parameters
print(preds)

# %%
# It is also possible to have linear problems in the complex domain. These
# commonly occur in signal processing problems. Let's define parameters and
# regressors X and generate the corresponding regressand y for an example
# problem. For the time being, there is no intercept.
params = np.array([2 + 3j])
X = np.array([1 + 1j, 2 + 1j, 3 + 1j, 1 + 2j, 2 + 2j, 3 + 2j]).reshape(6, 1)
y = np.matmul(X, params)

# %%
# Let's have a quick look at the values, again noting that there is only a
# single regressor.
print(f"Regressors X {X[:, 0]}")  # the single regressor is column 0 of X
print(f"Regressand y {y}")

# %%
# It is a bit harder to visualise the complex-valued version, but let's try and
# visualise the regressors X and regressands y.
fig, axs = plt.subplots(nrows=1, ncols=2)
# One panel per quantity: (axes, complex values, marker colour, title)
panels = (
    (axs[0], X, "tab:blue", "Regressors X"),
    (axs[1], y, "tab:red", "Regressand y"),
)
for ax, values, colour, title in panels:
    # Real part on the horizontal axis, imaginary part on the vertical axis
    ax.scatter(values.real, values.imag, c=colour)
    # Pad the limits by 3 on every side of the data
    ax.set_xlim(values.real.min() - 3, values.real.max() + 3)
    ax.set_ylim(values.imag.min() - 3, values.imag.max() + 3)
    ax.set_title(title)
plt.show()

# %%
# Visualising the regressors X and the regressand y this way gives a geometric
# indication of the linear problem in the complex domain. Multiplying the
# regressors by the parameters can be considered like a scaling and a rotation
# of the independent variables to give the dependent variables y.
#
# With more samples, this can be a bit easier to visualise. In the below
# example, regressors and the regressand are generated again, this time with
# more samples. To start off with, the parameter is a real number to
# demonstrate the scaling without any rotation. Both the regressors and
# regressand are plotted on the same axis with lines to show the mapping
# between independent and dependent values.
grid = ComplexGrid(r1=0, r2=10, nr=11, i1=-5, i2=5, ni=11)
X = grid.flat_grid()
# A purely real parameter: scaling only, no rotation
params = np.array([0.5])
y = np.matmul(X, params)

fig = plt.figure()
# Join every regressor value to the regressand value it maps to
for x_obs, y_obs in zip(X[:, 0], y):
    plt.plot(
        [y_obs.real, x_obs.real],
        [y_obs.imag, x_obs.imag],
        color="k",
        lw=0.5,
    )
plt.scatter(X.real, X.imag, c="tab:blue", label="Regressor")
plt.scatter(y.real, y.imag, c="tab:red", label="Regressand")
# Switch the grid on explicitly: a bare plt.grid() toggles the grid, so
# calling it twice would turn it straight back off.
plt.grid(True)
plt.legend()
plt.title("Complex regression")
plt.show()

# %%
# Now let's add an imaginary component to the parameter (coefficient) to
# demonstrate the rotational aspect.
params = np.array([0.5 + 2j])
y = np.matmul(X, params)

fig = plt.figure()
# Join every regressor value to the regressand value it maps to
for x_obs, y_obs in zip(X[:, 0], y):
    plt.plot(
        [y_obs.real, x_obs.real],
        [y_obs.imag, x_obs.imag],
        color="k",
        lw=0.5,
    )
plt.scatter(X.real, X.imag, c="tab:blue", label="Regressor")
plt.scatter(y.real, y.imag, c="tab:red", label="Regressand")
# Switch the grid on explicitly: a bare plt.grid() toggles the grid, so
# calling it twice would turn it straight back off.
plt.grid(True)
plt.legend()
plt.title("Complex regression")
plt.show()

# %%
# Finally, adding an intercept gives a translation.
params = np.array([0.5 + 2j])
intercept = 20 + 20j  # complex constant added to every regressand value
y = np.matmul(X, params) + intercept

fig = plt.figure()
# Join every regressor value to the regressand value it maps to
for x_obs, y_obs in zip(X[:, 0], y):
    plt.plot(
        [y_obs.real, x_obs.real],
        [y_obs.imag, x_obs.imag],
        color="k",
        lw=0.3,
    )
plt.scatter(X.real, X.imag, c="tab:blue", label="Regressor")
plt.scatter(y.real, y.imag, c="tab:red", label="Regressand")
# Switch the grid on explicitly: a bare plt.grid() toggles the grid, so
# calling it twice would turn it straight back off.
plt.grid(True)
plt.legend()
plt.title("Complex regression")
plt.show()


# %%
# Similar to the real-valued problem, linear regression can be used to estimate
# the values of the parameters for the complex-valued problem. Again, least
# squares is one of the most common methods of linear regression. However, not
# all least squares algorithms support complex data, though some do such as the
# least squares in Scipy. The focus of |pkgnm| is to provide regression methods
# for complex-valued data.
#
# Note that adding an intercept column to X allows for solving of the intercept.
# |pkgnm| does not automatically solve for the intercept and if desired, an
# intercept column needs to be added to the regressors X, similar to the
# real-valued example shown at the top of the page.
model = OLS()
# The intercept is only solved for when X carries an intercept column
X = add_intercept(X)
model.fit(X, y)
estimated_params = model.estimate.params
print(estimated_params)

# %%
# Finally, let's compare the actual regressand y to the predicted regressand
# calculated from the regressors X and the estimated parameters.
preds = model.predict(X)
# Gather the true inputs and outputs next to the model's predictions; the
# scalar parameter and intercept are repeated on every row by pandas.
comparison_columns = {
    "Regressor X": X[:, 0],
    "parameter": params[0],
    "intercept": intercept,
    "regressand y": y,
    "predicted y": preds,
}
df = pd.DataFrame(comparison_columns)
print(df)
