from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from multinterp.rectilinear._multi import MultivariateInterp
def trig_func(x, y):
    """Evaluate the test surface f(x, y) = y*sin(x) + x*cos(y).

    Operates elementwise, so ``x`` and ``y`` may be scalars or
    broadcast-compatible NumPy arrays (e.g. the outputs of ``np.meshgrid``).
    """
    sine_part = y * np.sin(x)
    cosine_part = x * np.cos(y)
    return sine_part + cosine_part
def trig_func_dx(x, y):
    """Analytic partial derivative of ``trig_func`` with respect to x.

    d/dx [y*sin(x) + x*cos(y)] = y*cos(x) + cos(y)

    Operates elementwise on scalars or broadcast-compatible NumPy arrays.
    """
    from_sine_term = y * np.cos(x)
    from_cosine_term = np.cos(y)
    return from_sine_term + from_cosine_term
def trig_func_dy(x, y):
    """Analytic partial derivative of ``trig_func`` with respect to y.

    d/dy [y*sin(x) + x*cos(y)] = sin(x) - x*sin(y)

    Operates elementwise on scalars or broadcast-compatible NumPy arrays.
    """
    from_sine_term = np.sin(x)
    from_cosine_term = x * np.sin(y)
    return from_sine_term - from_cosine_term
# Data grid: 1000 x 1000 tensor grid over [0, 10] x [0, 10]. Shifting
# geomspace(1, 11) down by 1 keeps the range at [0, 10] while packing points
# more densely near 0, so the grid is deliberately non-uniform.
n_points = 1000
x_grid = np.geomspace(1, 11, n_points) - 1
y_grid = np.geomspace(1, 11, n_points) - 1
x_mat, y_mat = np.meshgrid(x_grid, y_grid, indexing="ij")
z_mat = trig_func(x_mat, y_mat)

# Query grid: uniform 1000 x 1000 points over the same domain, "ij"-indexed
# like the data grid so the axes line up.
query_axis = np.linspace(0, 10, n_points)
x_new, y_new = np.meshgrid(query_axis, query_axis, indexing="ij")

# Interpolate with the CuPy backend. The result is presumably a CuPy device
# array; .get() copies it to a host NumPy array that matplotlib can plot.
mult_interp = MultivariateInterp(z_mat, [x_grid, y_grid], backend="cupy")
z_mult_interp = mult_interp(x_new, y_new).get()

# Exact function values at the query points, for visual comparison.
z_true = trig_func(x_new, y_new)
# Side-by-side 3-D surfaces: interpolated values (left) vs. the exact
# function (right), both over the uniform query grid.
fig = plt.figure(figsize=(12, 6))
panels = (
    ("Interpolated Function", z_mult_interp),
    ("True Function", z_true),
)
for position, (title, surface) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position, projection="3d")
    ax.plot_surface(x_new, y_new, surface)
    ax.set_title(title)
# Keep the per-panel names bound, as in the straight-line version.
ax1, ax2 = fig.axes
plt.show()
![<Figure size 1200x600 with 2 Axes>](https://cdn.curvenote.com/616d93c9-d385-465a-a37f-8e2dce3e5a1b/public/6dbc9e38c5450da29ddd91a50e54fe21.png)
# Partial derivative along axis 0 of the interpolator. This is presumably x,
# the first grid passed to MultivariateInterp; it is compared against the
# analytic df/dx from trig_func_dx.
dfdx = mult_interp.diff(0)
# .get() copies the (presumably CuPy) result back to a host NumPy array.
z_dfdx = dfdx(x_new, y_new).get()
dfdx_true = trig_func_dx(x_new, y_new)

# Side-by-side 3-D surfaces: interpolated df/dx (left) vs. analytic df/dx
# (right). The titles name the derivative so these panels are not mistaken
# for plots of f itself.
fig = plt.figure(figsize=(12, 6))
# Plot the interpolated derivative
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax1.plot_surface(x_new, y_new, z_dfdx)
ax1.set_title("Interpolated Derivative df/dx")
# Plot the true derivative
ax2 = fig.add_subplot(1, 2, 2, projection="3d")
ax2.plot_surface(x_new, y_new, dfdx_true)
ax2.set_title("True Derivative df/dx")
plt.show()
![<Figure size 1200x600 with 2 Axes>](https://cdn.curvenote.com/616d93c9-d385-465a-a37f-8e2dce3e5a1b/public/b386d619b15a411a19232a6931c926d5.png)
# Partial derivative along axis 1 of the interpolator. This is presumably y,
# the second grid passed to MultivariateInterp; it is compared against the
# analytic df/dy from trig_func_dy.
dfdy = mult_interp.diff(1)
# .get() copies the (presumably CuPy) result back to a host NumPy array.
z_dfdy = dfdy(x_new, y_new).get()
dfdy_true = trig_func_dy(x_new, y_new)

# Side-by-side 3-D surfaces: interpolated df/dy (left) vs. analytic df/dy
# (right). The titles name the derivative so these panels are not mistaken
# for plots of f itself.
fig = plt.figure(figsize=(12, 6))
# Plot the interpolated derivative
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax1.plot_surface(x_new, y_new, z_dfdy)
ax1.set_title("Interpolated Derivative df/dy")
# Plot the true derivative
ax2 = fig.add_subplot(1, 2, 2, projection="3d")
ax2.plot_surface(x_new, y_new, dfdy_true)
ax2.set_title("True Derivative df/dy")
plt.show()
![<Figure size 1200x600 with 2 Axes>](https://cdn.curvenote.com/616d93c9-d385-465a-a37f-8e2dce3e5a1b/public/078db127ca96915eb8165d2384160b04.png)