{ "cells": [ { "cell_type": "markdown", "id": "d68537ba", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "(bart_categorical)=\n", "# Categorical regression\n", "\n", ":::{post} May, 2024\n", ":tags: BART, regression\n", ":category: beginner, reference\n", ":author: Pablo Garay, Osvaldo Martin\n", ":::" ] }, { "cell_type": "markdown", "id": "0cf4f392-fdc7-4175-9e72-c8a334abea84", "metadata": {}, "source": [ "In this example, we will model outcomes with more than two categories. \n", ":::{include} ../extra_installs.md\n", ":::" ] }, { "cell_type": "code", "execution_count": 1, "id": "7c087cca", "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "import pymc_bart as pmb\n", "import seaborn as sns\n", "\n", "warnings.simplefilter(action=\"ignore\", category=FutureWarning)" ] }, { "cell_type": "code", "execution_count": 2, "id": "25cf7b45", "metadata": {}, "outputs": [], "source": [ "# set formats\n", "RANDOM_SEED = 8457\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "id": "e73740d8-8e70-48b4-b6f9-eb0c1f7ce72f", "metadata": {}, "source": [ "## Hawks dataset \n", "\n", "Here we will use a dataset that contains information about 3 species of hawks (*CH*=Cooper's, *RT*=Red-tailed, *SS*=Sharp-Shinned). The dataset covers 908 individuals in total, with 16 variables recorded for each one in addition to the species. To simplify the example, we will use the following 5 covariates: \n", "- *Wing*: Length (in mm) of the primary wing feather, from its tip to the wrist it attaches to. \n", "- *Weight*: Body weight (in g). \n", "- *Culmen*: Length (in mm) of the upper bill from the tip to where it bumps into the fleshy part of the bird. \n", "- *Hallux*: Length (in mm) of the killing talon. 
\n", "- *Tail*: Measurement (in mm) related to the length of the tail. \n", "\n", "We will also drop the rows that contain NaNs. We will use these covariates to predict the \"Species\" of each hawk, which is our dependent variable: the class we want to predict. " ] }, { "cell_type": "code", "execution_count": 3, "id": "71f3a9bc-979f-44fc-8227-133349e4dfb1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Wing | \n", "Weight | \n", "Culmen | \n", "Hallux | \n", "Tail | \n", "Species | \n", "
---|---|---|---|---|---|---|
0 | \n", "385.0 | \n", "920.0 | \n", "25.7 | \n", "30.1 | \n", "219 | \n", "RT | \n", "
2 | \n", "381.0 | \n", "990.0 | \n", "26.7 | \n", "31.3 | \n", "235 | \n", "RT | \n", "
3 | \n", "265.0 | \n", "470.0 | \n", "18.7 | \n", "23.5 | \n", "220 | \n", "CH | \n", "
4 | \n", "205.0 | \n", "170.0 | \n", "12.5 | \n", "14.3 | \n", "157 | \n", "SS | \n", "
5 | \n", "412.0 | \n", "1090.0 | \n", "28.5 | \n", "32.2 | \n", "230 | \n", "RT | \n", "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 155 seconds.\n", "Sampling: [y]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a9c75e927fb440f6ae09ae520f115714", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "with model_hawks:\n", " idata = pm.sample(chains=4, compute_convergence_checks=False, random_seed=123)\n", " pm.sample_posterior_predictive(idata, extend_inferencedata=True)" ] }, { "cell_type": "markdown", "id": "fb2e357e-502e-4ac5-9d53-928437bd2a4e", "metadata": {}, "source": [ "## Results \n", "\n", "### Variable Importance \n", "\n", "Some of the input variables may not be informative for classifying species, so, in the interest of parsimony and to reduce the computational cost of model estimation, it is useful to quantify the importance of each variable in the dataset. PyMC-BART provides the function {func}`~pymc_bart.plot_variable_importance()`, which generates a plot that shows the number of covariates on its x-axis and, on its y-axis, the R$^2$ (the square of the Pearson correlation coefficient) between the predictions made by the full model (all variables included) and those made by the restricted models, which include only a subset of the variables. The error bars represent the 94% HDI from the posterior predictive distribution. " ] }, { "cell_type": "code", "execution_count": 10, "id": "a9d1c616-8c1f-4907-ad5a-adffb290c0c2", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
<xarray.DataArray 'y' ()> Size: 8B\n", "array(96.34186308)
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 423 seconds.\n", "Sampling: [y]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fd3bc21fab6642fa80eb59bfd081901f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "with pm.Model(coords=coords) as model_t:\n", " μ_t = pmb.BART(\"μ\", x_0, y_0, m=50, separate_trees=True, dims=[\"species\", \"n_obs\"])\n", " θ_t = pm.Deterministic(\"θ\", pm.math.softmax(μ_t, axis=0))\n", " y_t = pm.Categorical(\"y\", p=θ_t.T, observed=y_0)\n", " idata_t = pm.sample(chains=4, compute_convergence_checks=False, random_seed=123)\n", " pm.sample_posterior_predictive(idata_t, extend_inferencedata=True)" ] }, { "cell_type": "markdown", "id": "60dc23a9-2351-4502-824a-944e0f454c4c", "metadata": {}, "source": [ "Now we are going to reproduce the same analyses as before. " ] }, { "cell_type": "code", "execution_count": 16, "id": "a05a3d39-307a-4c08-93ec-3a0503ea6c25", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
<xarray.DataArray 'y' ()> Size: 8B\n", "array(97.26565657)