from __future__ import annotations

import numpy as np
from matplotlib import pyplot as plt

from multinterp.unstructured._gpytorch import GaussianProcessRegression

%matplotlib inline
%load_ext autoreload
%autoreload 2
def function_1(x, y):
    """Evaluate the test surface x(1 - x) * cos(4*pi*x) * sin(4*pi*y**2)**2.

    Only arithmetic operators and NumPy ufuncs are used, so ``x`` and ``y``
    may be scalars or arrays; arrays are combined elementwise.
    """
    envelope = x * (1 - x)
    x_wave = np.cos(4 * np.pi * x)
    y_wave = np.sin(4 * np.pi * y**2) ** 2
    # Multiply left to right, same association as a single chained product.
    return envelope * x_wave * y_wave
# Seeded generator so the scattered sample locations are reproducible.
rng = np.random.default_rng(0)

# 1000 scattered sample points, coordinates drawn uniformly from [0, 1),
# and the test function evaluated at each of them.
rand_x, rand_y = rng.random((2, 1000))
values = function_1(rand_x, rand_y)

# Dense 100 x 100 evaluation grid; with indexing="ij" the first array axis
# follows x and the second follows y.
grid_x, grid_y = np.meshgrid(
    np.linspace(0, 1, 100), np.linspace(0, 1, 100), indexing="ij"
)

# Build the Gaussian-process interpolator from the scattered samples.
# NOTE(review): everything is cast to float32 before being handed over --
# presumably to match the dtype the torch backend expects; confirm.
interp = GaussianProcessRegression(
    values.astype("float32"), (rand_x.astype("float32"), rand_y.astype("float32"))
)

# Evaluate the interpolator on the dense grid.
new_grid = interp(grid_x.astype("float32"), grid_y.astype("float32"))
# Output:
# Iter 1/50 - Loss: -1.057   lengthscale: 0.704   noise: 0.012
# Convergence reached!
# Reference image: the exact function on the dense grid. The grid was built
# with indexing="ij" (axis 0 follows x), so the array is transposed to put x
# on the horizontal axis of the image.
unit_square = (0, 1, 0, 1)
exact_surface = function_1(grid_x, grid_y)
plt.imshow(exact_surface.T, extent=unit_square, origin="lower")
plt.plot(rand_x, rand_y, "ok", ms=2, label="input points")
plt.legend(loc="lower right")

# Predicted mean, moved to the CPU for plotting, with the same sample-point
# overlay.
# NOTE(review): unlike the reference image above, this array is drawn without
# a transpose -- confirm the axis layout of `new_grid.mean` matches the plot.
predicted_mean = new_grid.mean.cpu()
plt.imshow(predicted_mean, extent=unit_square, origin="lower")
plt.plot(rand_x, rand_y, "ok", ms=2, label="input points")
plt.legend(loc="lower right")
