Part 8: Seaborn¶
Seaborn is a high-level plotting library built on top of Matplotlib. It gives sane defaults, nice themes, and quick one-liners for common plots (histograms, scatter, line plots, pairplots, …).
We’ll use it to visualize results from our Stack Overflow survey analysis.
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme()

1. Toy dataset import¶
# List the example datasets seaborn offers, then load the "car_crashes" one.
print(sns.get_dataset_names())
crash_df = sns.load_dataset(name="car_crashes")
crash_df

sns.histplot(crash_df['not_distracted'])

sns.histplot(crash_df['not_distracted'], kde=True)

sns.histplot(crash_df['not_distracted'], kde=False, bins=25)

2. Jointplot¶
sns.jointplot(x='speeding', y='alcohol', data=crash_df, kind='kde') # or hex or reg

3. Pairplot: relationship between each pair of features¶
# Pairwise scatter/histogram grid over every numeric column of the crash data.
sns.pairplot(crash_df)

# Second toy dataset; the bare name displays the frame when run in a notebook.
peng_df = sns.load_dataset("penguins")
peng_df

# Same grid for the penguins, with points coloured by the 'sex' column.
sns.pairplot(peng_df, hue='sex')

# Many styles and customizations
sns.set_style("white")  # or whitegrid or dark or darkgrid or ticks
sns.set_context("notebook")  # or paper or talk or poster
sns.despine(left=True, bottom=True) # remove left and bottom spines

4. Bar plot: categorical data¶
# One bar per sex; bar height is the mean body mass by default
# (pass estimator=np.median to plot medians instead).
sns.barplot(data=peng_df, x='sex', y='body_mass_g')
sns.barplot(x='sex', y='body_mass_g', data=peng_df, hue='sex')

5. Plots on the big dataset¶
import pandas as pd
# Load the survey responses (path is relative to the working directory).
df = pd.read_csv("data/survey_results_public.csv")

# Quick look at the raw column: values, values without NaN, and frequencies.
# These bare expressions only display something when run as notebook cells.
df["YearsCode"]
df["YearsCode"].dropna()
df["YearsCode"].value_counts()

# Map the two textual answers to numbers so the whole column can be cast to
# float; missing answers stay NaN.
years = (df["YearsCode"].replace({"Less than 1 year": 0, "More than 50 years": 50})).astype(float)
# Each variant gets its own figure: histplot draws onto the current axes, so
# without a fresh figure the successive calls would pile up in a single plot.

# Basic plot
plt.figure()
sns.histplot(years)

# Bins tuning
plt.figure()
sns.histplot(years, bins=10)

# KDE plotting
plt.figure()
sns.histplot(years, kde=True)

# Figure size
plt.figure(figsize=(10, 4))
sns.histplot(years, bins=5)
plt.show()

salaries = df['CompTotal'].astype(float)
# Inspect the 200 largest values to see how extreme the outliers are.
list(salaries.nlargest(200))
# Drop every value at or above 1_000_000 (NaN rows fail the comparison and
# are dropped as well). Capping beforehand would be pointless because this
# filter discards the capped rows; to cap instead of drop, use
# salaries.clip(upper=1_000_000) in place of the line below.
salaries = salaries[salaries < 1_000_000]
sns.histplot(salaries)

6. Jointplot¶
Pick two numeric variables that make sense together.
If you’ve created the salaries and years (of coding) variables, a simple joint scatter works:
sns.jointplot(x=salaries, y=years, kind="scatter", height=10) # or kde

7. Scatter Plot: AI Adoption vs. Average Salary¶
Let’s recreate the final table from Exercise 6.1 – sol6_1 – and inspect the correlation between AI adoption and average salary.
# Flag respondents who list Python among the languages they have worked with
# (missing answers count as False).
df["knows_Python"] = df["LanguageHaveWorkedWith"].str.contains("Python", na=False)

# Restrict the survey to five countries and group by country.
countries = ["Switzerland", "Germany", "United States of America", "France", "Italy"]
sub = df.loc[df["Country"].isin(countries)]
grp = sub.groupby("Country")

# Share (in %) of each country's respondents whose AISelect answer is "yes",
# ignoring surrounding whitespace and letter case.
pct_ai = grp["AISelect"].apply(lambda s: s.str.strip().str.lower().eq("yes").mean() * 100)
# NOTE(review): named "avg"/"AvgSalary" but computed as the per-country median.
avg_sal = grp["CompTotal"].median()

# Side-by-side table; `keys` supplies the final column names directly.
sol6_1 = pd.concat([pct_ai, avg_sal], axis=1, keys=["PctAIUsers", "AvgSalary"])

# Ensure Country is a column, not index
sol6_1 = sol6_1.reset_index().rename(columns={"index": "Country"})

plt.figure(figsize=(6, 4))
ax = sns.scatterplot(
    x="PctAIUsers",   # horizontal axis: AI adoption
    y="AvgSalary",    # vertical axis: salary
    hue="Country",    # one colour per country
    s=100,            # marker size
    palette="deep",   # colour palette for the countries
    data=sol6_1,
)
ax.set(
    title="AI Adoption (%) vs. Average Salary",
    xlabel="Percent AI Users",
    ylabel="Average Salary (USD)",
)
ax.legend(title="Country")
plt.tight_layout()
plt.show()

8. Line Plot: Median Salary & Python Users by Age Bracket¶
Let’s now plot some useful information from the table created in Exercise 6.3 – age_stats.
# 'age_stats' is indexed by age bracket and has the columns
# MedianSalary, NumKnowsPython and NRespondents.
age_grp = df.groupby("Age")
median_salary = age_grp["CompTotal"].median().rename("MedianSalary")
num_python = age_grp["knows_Python"].sum().rename("NumKnowsPython")
n_respondents = age_grp["Age"].count().rename("NRespondents")
age_stats = pd.concat(
    [median_salary, num_python, n_respondents],
    axis=1
)
# Relabel the youngest bracket so that a plain lexicographic sort puts it first.
age_stats.index = age_stats.index.str.replace(
    "Under 18 years old", "0-18 years_old", regex=False
)
# The plots below look the bracket up by name (x="Age"), so keep the index named.
age_stats.index.name = "Age"
age_stats = age_stats.sort_index()

# Plot 1: median salary per age bracket.
plt.figure(figsize=(8, 4))
sns.lineplot(
    data=age_stats,
    x="Age",
    y="MedianSalary",
    marker="o",
    label="Median Salary"
)
plt.xticks(rotation=90)
plt.title("Median Salary by Age")
plt.xlabel("Age")
plt.ylabel("Salary")
plt.legend()
plt.tight_layout()
plt.show()

# Plot 2: number of Python users per age bracket.
plt.figure(figsize=(8, 4))
sns.lineplot(
    data=age_stats,
    x="Age",
    y="NumKnowsPython",
    marker="s",
    label="# Knows Python"
)
plt.xticks(rotation=90)
plt.title("Number of Python Users by Age")
plt.xlabel("Age")
plt.ylabel("Number of Python Users")
plt.legend()
plt.tight_layout()
plt.show()

9. Histogram: Distribution of Respondents by Age Bracket¶
(from Exercise 6.3 – age_stats)
# Turn the age-bracket index into a regular column for plotting. The rename
# only matters if the index was unnamed, in which case reset_index() calls
# the new column "index".
age_stats_plot = age_stats.reset_index()
age_stats_plot = age_stats_plot.rename(columns={"index": "Age"})

plt.figure(figsize=(8,4))
ax = sns.barplot(
    x="Age",
    y="NRespondents",
    hue="Age",            # colour each bracket separately
    palette="Blues_d",
    data=age_stats_plot,
)
ax.set(
    title="Number of Respondents per Age Bracket",
    xlabel="Age Bracket",
    ylabel="Respondent Count",
)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
Takeaways:
Scatter plots reveal relationships (e.g., higher salary ↔ AI adoption).
Line plots are ideal for comparing metrics across ordered categories.
Histograms show distributions of a single variable.
Pairplots give a quick overview of all pairwise correlations among multiple numeric metrics.