anisotropy/playground/training-models.ipynb

304 lines
238 KiB
Plaintext
Raw Normal View History

2021-10-04 16:08:29 +05:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c054d421-9717-4feb-bb30-d080d91ceac5",
"metadata": {},
"outputs": [],
"source": [
"from anisotropy.core.database import Database, Structure\n",
"from pandas import DataFrame, Series\n",
"import matplotlib.pyplot as plt\n",
"import seaborn\n",
"import numpy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7ec2ae1b-8a7e-4d50-b1bd-201fc3cd190d",
"metadata": {},
"outputs": [],
"source": [
"db = Database(\"anisotropy\", \"woPrismaticLayer\")\n",
"db.setup()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "75498f51-fa01-494e-a64e-e650138d1ca2",
"metadata": {},
"outputs": [],
"source": [
"res = db.search([Structure.type == \"simple\", Structure.direction == str([1.0, 0.0, 0.0])])\n",
"df = DataFrame(res)"
]
},
{
"cell_type": "code",
"execution_count": 4,
2021-10-06 00:58:37 +05:00
"id": "3bc80bca-03ce-416e-be2f-ab7ae2af95fa",
2021-10-04 16:08:29 +05:00
"metadata": {},
"outputs": [],
"source": [
2021-10-06 00:58:37 +05:00
"df_numeric = df[[\n",
" col for col in df.columns \n",
" if not isinstance(df[col][0], str) \n",
" and not isinstance(df[col][0], numpy.bool_)\n",
" and not isinstance(df[col][0], dict)\n",
" and not isinstance(df[col][0], list)\n",
" and not df[col][0] is None\n",
" and not col[-3: ] == \"_id\"\n",
"]]"
2021-10-04 16:08:29 +05:00
]
},
{
"cell_type": "code",
2021-10-06 00:58:37 +05:00
"execution_count": 5,
"id": "eab92bcc-8f1f-4490-8dc6-8ebf191d6cf2",
2021-10-04 16:08:29 +05:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:>"
]
},
2021-10-06 00:58:37 +05:00
"execution_count": 5,
2021-10-04 16:08:29 +05:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABn8AAATSCAYAAAB/xfFzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd1xTVx8G8IcdNuJAFAUnIKBi3XsPEBEXKm7rRHFvW7Xu2VoHfbUq1j0QAbfUUQda9wIFUVCGoCJ7E94/oBGaaAVJQuLz/Xz4vMm9J/FJ3pNz0vzuPVclLy8vD0RERERERERERERERKQUVOUdgIiIiIiIiIiIiIiIiEoPiz9ERERERERERERERERKhMUfIiIiIiIiIiIiIiIiJcLiDxERERERERERERERkRJh8YeIiIiIiIiIiIiIiEiJsPhDRERERERERERERESkRFj8ISIiIiIiIiIiIiIiKqbVq1ejY8eOsLS0REhIiMQ2ubm5WLJkCTp37owuXbrgyJEjX7Tva6mX2jMRERERERERERERERF9Izp16oRhw4bBzc3tk238/f3x6tUrnDt3DgkJCejduzdatGgBMzOzz+77Wjzzh4iIiIiIiIiIiIiIqJgaN24MU1PTz7Y5deoU+vfvD1VVVRgbG6Nz5844c+bMf+77Wjzzh4iIiIiIiIiIiIiICEBSUhKSkpLEthsYGMDAwKDYzxcTE4MqVaqI7puamuLNmzf/ue9rsfhDn5X97oW8I5RIQl5FeUf4phipvJV3hBLTrtJG3hFKJC3slLwjfFMSBTXlHYFIqgwzFHO+16nlIO8IJRIT9VTeEb45Vc2s5R2hRIR5efKOUCKK2sf1rm6Rd4SSUVXQBT2EQnknIAWR0tpd3hG+KXrn1sg7QomYjN0n7wgllpadKe8IJZK8f4K8I5SIdr+F8o6gVBT1t+P/svvASWzevFls+6RJkzB58mQ5JCoZFn+IiIiIiIiIiIiIiIgADB8+HC4uLmLbS3LWD5B/Nk90dDTq168PoOjZPp/b97UU9BAhIiIiIiIiIiIiIiKi0mVgYAAzMzOxv5IWf7p3744jR45AKBQiPj4eAQEB6Nat23/u+1o884eIiIiIiIiIiIiIiKiYli1bhnPnzuHdu3cYOXIkjIyMcPLkSYwZMwYeHh6ws7ODs7MzHjx4gK5duwIA3N3dUa1aNQD47L6vxeIPEREREREREREREREVjzBX3gnkbuHChVi4UPxaUtu3bxfdVlNTw5IlSyQ+/nP7vhaXfSMiIiIiIiIiIiIiIlIiLP4QEREREREREREREREpERZ/iIiIiIiIiIiIiIiIlAiv+UNERERERERERERERMWTJ5R3AvoMnvlDRERERERERERERESkRFj8ISIiIiIiIiIiIiIiUiIs/hARERERERERERERESkRFn+IiIiIiIiIiIiIiIiUiLq8A9C3Zf9RPxw/FYDQFy/h0Lk9li+cIe9IVEYpcl+xsbHE2tU/olGj+qhQwRjqmlXlmicxKQU/rvdE4J2HMDLQx5TRg+HYqbVYu6SUVKzesgtXb90HALg6dcXE4QNE+7u5ueP9hwSoquYfN9DQxhLbVi9kdqJvzLfwuSxr4zjJh4fH95g5YyJ0dLRx7NhJTJo8H1lZWRLbjhw5CLNmuaOySUVcu34LY8fOQExMLABAU1MTGzYsgXOv7tDQ0EBg4C24T5qH6Og3snw5IuzfpScxLROLj15FYGg0yulqYXL37+DQsJbEtsFR77DW/28ER7+HtqY6RrevD7fWNgCAp9HvsdrvJkJj4qGjpYF+zSwxtlND6eY+cgWBIVH5uXs0gYP9J3JHvsNa/xsIjirI3bEB3FrbAgDuh8dirf8NvIxNRFVjPcx3aQn7GpWlm1tR32/mlllukr3E9Ews9ruNwBdvUE5HC5M72sHBzlxi2+CYD1h79h6CYxKgramG0a2t4dasLgDg+90XEfY2CVk5uahaThcT2tuig6X05yj3SaMwdfo4aGsL4Hv8DKZN+eGT8/2w4QMwfcZ4VDKpiBuBtzFx/By8eRNXpI2Ghgau3zgJPX1dWNdtJfX8n1LW5vvEtEwsPhaIwOfRKKcrwOSu9nBoUENi2+Co91h76jaCo+OhraGO0e1t4dbSGjEJqeiz0a9I2/SsHEzv8R2Gta4ni5dBkgiF8k5An8HiTxmxadMmjBs3Dpqampg7dy5sbW0xZMiQYj3HsWPHYG9vjxo1JA+eZUHFCuUxbsRAXLt5B5mZkidTIkCx+0p2dg6OHPWH5/92w8d7l7zjYPmm36Ghro5LR7bj6fNwuC9YCcta5qhtUa1IuzWeu5GemYUze7cgPiER389aClOTinDp3kHUZtPSOWjxXX1mJ/qGfQufy7I2jpPsdenSDrNmuqNbN1dEx8TiyOHfsejHGViwcKVY27ZtW2DpT3PQtesAhD5/iQ0blmDPH1vQuUs/AMDkyaPRvNl3+K5xFyQmJsNz62r88vNSDHAdI+uXBYD9uzSt9A2EhroqLiwciGcx8Zi86zzqmhqjtkm5Iu0+pGZg4s7zmNmzKbrYWSA7V4jYxFTR/vkHL6ODjTl+H9sd0R9SMPK3U6hraoz29apLJ7fPdWioqeLCj4PxLPo9Ju86l5+7soTcO85iplMzdKlfA9k5uYhNTAOQ/yPeFK/zWNCnFTrZmuPM/Rfw8DqPk3MGwEBHSzq5FfX9Zm6Z5ibZW3nqbv6YMqMXnr1JwOQDV1HXxAi1KxkWafchLRMT9/2Fmd0aoou1WX5fSUoX7Z/d3R41KxpAXVUVjyLfY9zey/B174GK+tpSy96pcxtMmzEePR3c8CYmFvsO/Ib5C6di8Y9rxNq2btMMi5bMhGMPN4Q9D8fqtT9ip9dGOHQfVKTdlKlj8O5dPPT0daWW+0uUtfl+pf/f+WPKvP54FvMBk/+4gLqVy6G2iVGRdh9SMzBx95+Y6dAYXWzNC8aU/LnH1EgXgYs+vt9R8clw2uCLTjYcT4g+hcu+lRGbN29Gdnb2Vz2Hj48PwsPDSyeQlHRp3wqd2raEkaGBvKNQGafIfSUkJAy7vA4iKChE3lGQlp6B81duYtJIV+hoC9DIzgrtWzaG//m/xNpeDryDUQN6QVughaqVK6FPjw44fuaiHFLnU+TsRMrqW/lclqVxnORj6JD+8PI6iKDgECQkJGLFyl8wdGh/iW0dHDrB+9gJBAWHIDs7GytWbETbts1Rs2b+Uc8WFtVw/vxlxMW9Q2ZmJo4c9UO9enVl+XKKYP8uHelZ2Qh4HAH3Lo2go6UBewsTtKtXHSfvhom13XPlCVrWrQpH+1rQVFeDrpYGalYyEu2P/pACh4a1oKaqimrlDdDQwgRhsQlSzB0O927f5eeuUbkg93Px3H89RkvLqnBsVDs/t0ATNQt+pHsQHovy+troWr8G1FRV4dioNsrpCvDn43Ap5lbU95u5ZZWbZC89KwcBwVFw72ALHU0N2FeviHZ1q+DkwwixtnsCn6FlrcpwtDP/2Fcqfvzv/bomRlAvOCMcKkBOrhBvktKkmn+wW1/8sfswngaHIiEhCWtWb4bbkL4S23bv0RHHj53G0+BQZGdnY82qTWjdphlq1PhYeDA3N4PrwN7YsN5Tqrm/RFma79OzshHw5BXcOzcsGFMqoZ21GU7efyHWds+1YLSsUwWODWsWGlMMJTwr4H/vBRpZVELVcnrSfglECovFnzJgyZIlAICBAwfC2dkZSUlJCAkJwbBhw9C1a1fMnj0beXl5AICUlBQsWLAA/fr1g5OTE5YtW4bc3Fx4e3vj8ePHWLZsGZydnXH9+nU8e/YMgwcPhouLCxwcHODl5SXHV0lE8hARGQN1NTVYmFURbbOsaY6wiNcS2+cVvp0HhIYXbTd35Sa07TsaY+csw7OwcCkk/kiRsxMpK34u6VtRr15dPHwYJLr/8GEQKleuBGNjI4ntVVRUxG7b2FgCAHbtOogWLRrD1NQE2toCDBrYB2fOKkYhlD4t4m0S1FVVYF7x4w9SdU3LSfxR+9GrOBhoa2LY1hPosPQAPLwCEJOQIto/uJUNTtx9juxcIcL
"text/plain": [
"<Figure size 2160x1440 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"seaborn.set(rc = { \"figure.figsize\": (30, 20) })\n",
"seaborn.heatmap(df_numeric.corr(), annot = True)"
]
},
{
"cell_type": "code",
2021-10-06 00:58:37 +05:00
"execution_count": 6,
"id": "7611c892-1e13-404e-849e-3cfa15d8dfca",
2021-10-04 16:08:29 +05:00
"metadata": {},
"outputs": [],
2021-10-06 00:58:37 +05:00
"source": [
"x = df_numeric[[\"theta\", \"r0\", \"L\", \"radius\"]] #.drop(columns = [\"flowRate\"])\n",
"y = df_numeric[\"flowRate\"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1a38aaa0-89fb-40f7-abce-e500ab696349",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"xtr, xte, ytr, yte = train_test_split(x, y, test_size = 0.2, random_state = 100)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1b7e9f62-35a9-4a1a-9ffa-a2a6e757ff46",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"scaler = preprocessing.MinMaxScaler()\n",
"x_scaled = scaler.fit_transform(xtr)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "19bd2f52-a5bf-404c-93f7-0626bc35a915",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeRegressor(random_state=500)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"neigh = DecisionTreeRegressor(random_state = 500)\n",
"neigh.fit(x_scaled, ytr)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d063265f-2f06-4290-811e-b6b1753bcac9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.664398739090909e-15"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import mean_absolute_error\n",
"xte_scaled = scaler.transform(xte)\n",
"y_pred = neigh.predict(xte_scaled)\n",
"mean_absolute_error(yte, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "762af50c-371a-4ba5-b5ab-763970ef2a37",
"metadata": {},
"outputs": [],
"source": [
"#df_numeric[[\"theta\", \"r0\", \"L\", \"radius\", \"flowRate\", \"volumeCell\", \"volume\"]]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7a951647-efa1-4231-b96d-f67bc0ed6c33",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>theta</th>\n",
" <th>r0</th>\n",
" <th>L</th>\n",
" <th>radius</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.29</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>1.408451</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" theta r0 L radius\n",
"0 0.29 1.0 2.0 1.408451"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_df = DataFrame([{\n",
" \"theta\": 0.29,\n",
" \"r0\": 1.,\n",
" \"L\": 2.,\n",
" \"radius\": 1. / (1. - 0.29)\n",
"}]); test_df"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bc9204db-bde4-4032-b4f5-313e99e25ab5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([4.53058768e-15])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"neigh.predict(scaler.transform(test_df))"
]
2021-10-04 16:08:29 +05:00
},
{
"cell_type": "code",
"execution_count": null,
2021-10-06 00:58:37 +05:00
"id": "4b95277f-c49e-42b9-8d5f-3334d5ca7729",
2021-10-04 16:08:29 +05:00
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
2021-10-06 00:58:37 +05:00
"id": "c56035b1-18d5-4c88-94b6-da4d7eccac9a",
2021-10-04 16:08:29 +05:00
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
2021-10-06 00:58:37 +05:00
"display_name": "Python 3 (ipykernel)",
2021-10-04 16:08:29 +05:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
2021-10-04 16:08:29 +05:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}