{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-21T00:28:49.748311944Z",
     "start_time": "2023-06-21T00:28:49.744946948Z"
    },
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pytelem.optimus import *\n",
    "from jax import jit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-21T00:28:50.052452543Z",
     "start_time": "2023-06-21T00:28:50.050159245Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'optim' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ffast \u001b[39m=\u001b[39m jit(optim\u001b[39m.\u001b[39msolar_position)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'optim' is not defined"
     ]
    }
   ],
   "source": [
    "ffast = jit(optim.solar_position)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-21T00:29:01.075229307Z",
     "start_time": "2023-06-21T00:29:01.042571064Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "timestamp = np.array(1687306901)\n",
    "timestamps = np.array([1687306901, 1687306902, 1687306903, 1687306906])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "jd = timestamps / 86400.0 + 2440587.5\n",
    "jde = jd + DELTA_T / 86400.0\n",
    "jce = (jde - 2451545) / 36525\n",
    "\n",
    "jme = jce / 10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# todo: make more elegant?\n",
    "# TODO: vectorize later? it's kinda complex\n",
    "\n",
    "# heliocentric longitude\n",
    "l_rad = np.zeros_like(timestamp)\n",
    "for idx, vec in enumerate(HELIO_L):\n",
    "    l_rad = l_rad + helio_vector(vec, jme) * jme ** idx\n",
    "\n",
    "l_rad = l_rad / 10e8\n",
    "l_deg = np.rad2deg(l_rad) % 360\n",
    "\n",
    "# heliocentric latitude\n",
    "b_rad = np.zeros_like(timestamp)\n",
    "for idx, vec in enumerate(HELIO_B):\n",
    "    b_rad = b_rad + helio_vector(vec, jme) * jme ** idx\n",
    "b_rad = b_rad / 10e8\n",
    "b_deg = b_rad % 360\n",
    "\n",
    "# heliocentric radius\n",
    "r_rad = np.zeros_like(timestamp)\n",
    "for idx, vec in enumerate(HELIO_R):\n",
    "    r_rad = r_rad + helio_vector(vec, jme) * jme ** idx\n",
    "r_rad = r_rad / 10e8\n",
    "r_deg = r_rad % 360\n",
    "\n",
    "theta = (l_deg + 180) % 360\n",
    "beta = -1 * b_deg\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 5)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def cubic_poly(a,b,c,d):\n",
    "    return a + b * jce + c * jce ** 2 + (jce ** 3) / d\n",
    "X0 = cubic_poly(297.85036, 445267.111480, -0.0019142, 189474)\n",
    "X1 = cubic_poly(357.52772, 35999.050340, -0.0001603, -300000)\n",
    "X2 = cubic_poly(134.96298, 477198.867398, 0.0086972, 56250)\n",
    "X3 = cubic_poly(93.27191, 483202.017538, -0.0036825, 327270)\n",
    "X4 = cubic_poly(125.04452, 1934.136261, 0.0020708, 450000)\n",
    "\n",
    "X = np.vstack([X0, X1, X2, X3, X4]).T\n",
    "print(X.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "nut = NUTATION_ABCD_ARRAY\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 63)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(nut[:,0] + jce[..., np.newaxis] * nut[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 5)\n",
      "(63, 5)\n",
      "d psi shape (4, 63)\n",
      "[-0.00333032 -0.00333032 -0.00333032 -0.00333032]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(X.shape)\n",
    "print(NUTATION_YTERM_ARRAY.shape)\n",
    "d_psi = (nut[:,0] + jce[..., np.newaxis] * nut[:,1]) * np.sin(np.sum(X[:, np.newaxis, :] * NUTATION_YTERM_ARRAY[np.newaxis, ...], axis=-1))\n",
    "\n",
    "print(f\"d psi shape {d_psi.shape}\")\n",
    "\n",
    "nut_long = np.sum(d_psi, axis=-1) / 36,000,000\n",
    "\n",
    "print(nut_long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 11)\n"
     ]
    }
   ],
   "source": [
    "u = timestamps[:, np.newaxis] ** np.arange(0,11)\n",
    "print(u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "epsilon_0 = np.array(\n",
    "        [[\n",
    "            84381.448,\n",
    "            -4680.93,\n",
    "            1.55,\n",
    "            1999.25,\n",
    "            -51.38,\n",
    "            -249.67,\n",
    "            -39.05,\n",
    "            7.12,\n",
    "            27.87,\n",
    "            5.79,\n",
    "            2.45,\n",
    "        ]]\n",
    "    )\n",
    "res = u * epsilon_0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([-8.6115017e+12, -5.7323689e+12, -6.4693841e+12, -1.1346544e+13],      dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.sum(axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}