!pip install statsmodels==0.14.6
# [notebook UI placeholder] Run to view results
# ============================================================
# Replicate the analyses and figures in Python
# Data:
# - Contraceptive prevalence (all methods) from your .csv file
# - Fertility rate from the World Bank API
#
# Outputs:
# 1) Mean contraceptive prevalence in 1970
# 2) Std. dev. of contraceptive prevalence in 1970
# 3) Country with highest contraceptive use in 2000
# 4) Country with lowest fertility in 2000
# 5) Correlation between fertility and contraceptive use in 2000
# 6) Linear regression of fertility on contraceptive use in 2000
# 7) Elasticity at the sample mean
# 8) Figures:
# - scatterplot with fitted regression line
# - residual plot
# ============================================================
# Install first if needed:
# pip install pandas numpy matplotlib statsmodels requests
import pandas as pd
import numpy as np
import requests
import statsmodels.api as sm
import matplotlib.pyplot as plt
# ----------------------------
# 1. File paths / settings
# ----------------------------
# Path of the World Bank CSV with contraceptive prevalence; it is handed to
# load_contraception_data() further down. The path is relative, so the file is
# looked up in the current working directory.
contra_file = "API_SP.DYN.CONU.ZS_DS2_en_csv_v2_2894.csv"
# Indicator code that fetch_world_bank_indicator() requests from the World Bank API.
fertility_indicator = "SP.DYN.TFRT.IN" # Fertility rate, total (births per woman)
# ----------------------------
# 2. Helper functions
# ----------------------------
def normalize_column_name(col):
    """
    Convert year-like column names to integers, keep others unchanged.

    Parameters
    ----------
    col : object
        A column label, typically a string such as "1970" or "Country Name".

    Returns
    -------
    int or object
        ``int(float(col))`` when ``col`` parses as a float with no fractional
        part (e.g. "1970" or "2000.0"); otherwise ``col`` itself, untouched.
    """
    try:
        val = float(col)
    except (TypeError, ValueError, OverflowError):
        # Not numeric at all (e.g. "Country Name", None, or an int too large
        # for a float): leave the label as it is.
        return col
    # NaN and +/-inf parse as floats but are not integers, so they fall
    # through to the "unchanged" branch as well.
    return int(val) if val.is_integer() else col
def load_contraception_data(filepath):
    """
    Load the World Bank CSV file for contraceptive prevalence.

    Assumes the main data file has 4 metadata rows before the header.
    Keeps only actual countries/territories by excluding aggregates.

    Parameters
    ----------
    filepath : str or path-like
        Location of the CSV file; passed straight to ``pandas.read_csv``.

    Returns
    -------
    pandas.DataFrame
        Columns ``country_name`` and ``country_code`` followed by one column
        per year, where the year labels are integers (as produced by
        ``normalize_column_name``). All other columns are dropped.
    """
    data = pd.read_csv(filepath, skiprows=4)
    data.columns = [normalize_column_name(c) for c in data.columns]

    # World Bank aggregates (regions, income groups, lending groups). They carry
    # their own country codes, so they can only be removed by name.
    aggregate_names = {
        "Africa Eastern and Southern", "Africa Western and Central",
        "Arab World", "Central Europe and the Baltics", "Caribbean small states",
        "East Asia & Pacific", "Early-demographic dividend", "East Asia & Pacific (excluding high income)",
        "East Asia & Pacific (IDA & IBRD countries)", "Euro area", "Europe & Central Asia",
        "Europe & Central Asia (excluding high income)", "Europe & Central Asia (IDA & IBRD countries)",
        "European Union", "Fragile and conflict affected situations", "High income",
        "Heavily indebted poor countries (HIPC)", "IBRD only", "IDA & IBRD total", "IDA blend",
        "IDA only", "IDA total", "Late-demographic dividend", "Latin America & Caribbean",
        "Latin America & Caribbean (excluding high income)", "Latin America & Caribbean (IDA & IBRD countries)",
        "Latin America & the Caribbean (IDA & IBRD countries)",
        "Least developed countries: UN classification", "Low income", "Lower middle income",
        "Low & middle income", "Middle East & North Africa", "Middle income",
        "Middle East & North Africa (excluding high income)", "Middle East & North Africa (IDA & IBRD countries)",
        "North America", "Not classified", "OECD members", "Other small states", "Pacific island small states",
        "Post-demographic dividend", "Pre-demographic dividend", "Small states",
        "South Asia", "South Asia (IDA & IBRD)", "Sub-Saharan Africa",
        "Sub-Saharan Africa (excluding high income)", "Sub-Saharan Africa (IDA & IBRD countries)",
        "Upper middle income", "World"
    }

    # Keep rows that have a country code and are not a named aggregate.
    is_country = data["Country Code"].notna() & ~data["Country Name"].isin(aggregate_names)

    # Keep the two identifier columns plus every integer-labelled (year) column,
    # in the order they appear in the file.
    keep_cols = [
        c for c in data.columns
        if c in ["Country Name", "Country Code"] or isinstance(c, int)
    ]

    return data.loc[is_country, keep_cols].rename(columns={
        "Country Name": "country_name",
        "Country Code": "country_code"
    })
def fetch_world_bank_indicator(indicator_code):
    """
    Download one World Bank indicator for all countries via the API.

    Every result page reported by the API is requested, so the result stays
    complete even when the number of observations exceeds ``per_page``.

    Parameters
    ----------
    indicator_code : str
        World Bank indicator code, e.g. "SP.DYN.TFRT.IN".

    Returns
    -------
    pandas.DataFrame
        Tidy frame with the columns ``country_code``, ``country_name_api``,
        ``year`` (int) and ``value``. The columns are present even when the
        API returns no usable records.

    Raises
    ------
    requests.HTTPError
        If a request returns an HTTP error status.
    ValueError
        If a response is not the expected two-element ``[metadata, records]``
        JSON list.
    """
    url = f"https://api.worldbank.org/v2/country/all/indicator/{indicator_code}"
    columns = ["country_code", "country_name_api", "year", "value"]
    rows = []
    page = 1
    while True:
        r = requests.get(
            url,
            params={"format": "json", "per_page": 20000, "page": page},
            timeout=60,
        )
        r.raise_for_status()
        payload = r.json()
        if not isinstance(payload, list) or len(payload) < 2:
            raise ValueError("Unexpected World Bank API response format.")
        meta, records = payload[0], payload[1]

        # The records element may be null when there is no data.
        for item in records or []:
            country_code = item.get("countryiso3code")
            year = item.get("date")
            # Skip entries without an ISO3 code or without a date.
            if country_code and year is not None:
                rows.append({
                    "country_code": country_code,
                    # Tolerate a null "country" object.
                    "country_name_api": (item.get("country") or {}).get("value"),
                    "year": int(year),
                    "value": item.get("value")
                })

        # The metadata element reports the total number of pages.
        total_pages = meta.get("pages") if isinstance(meta, dict) else None
        if page >= int(total_pages or 1):
            break
        page += 1

    return pd.DataFrame(rows, columns=columns)
# ----------------------------
# 3. Load contraception data
# ----------------------------
contra = load_contraception_data(contra_file)

# ----------------------------
# 4. 1970 contraceptive prevalence
# ----------------------------
# Restrict to countries that actually report a 1970 value.
c1970 = (
    contra[["country_name", "country_code", 1970]]
    .rename(columns={1970: "contraceptive_use_1970"})
    .dropna(subset=["contraceptive_use_1970"])
)

mean_1970 = c1970["contraceptive_use_1970"].mean()
# Both conventions are reported: ddof=1 (sample) and ddof=0 (population).
sd_1970_sample, sd_1970_population = (
    c1970["contraceptive_use_1970"].std(ddof=d) for d in (1, 0)
)

print(
    "=== 1970 CONTRACEPTIVE PREVALENCE ===",
    f"N countries with data: {len(c1970)}",
    f"Mean: {mean_1970:.4f}",
    f"Sample SD: {sd_1970_sample:.4f}",
    f"Population SD: {sd_1970_population:.4f}",
    "",
    sep="\n",
)
# ----------------------------
# 5. Highest contraceptive use in 2000
# ----------------------------
# Countries with a reported 2000 value, with the year column given a
# descriptive name for the later merge and regression.
c2000 = (
    contra[["country_name", "country_code", 2000]]
    .rename(columns={2000: "contraceptive_use"})
    .dropna(subset=["contraceptive_use"])
)

# Row of the country with the largest 2000 value.
top_contra_2000 = c2000.loc[c2000["contraceptive_use"].idxmax()]

print(
    "=== HIGHEST CONTRACEPTIVE USE IN 2000 ===",
    f"Country: {top_contra_2000['country_name']}",
    f"Value: {top_contra_2000['contraceptive_use']:.4f}",
    "",
    sep="\n",
)
# ----------------------------
# 6. Load fertility data from World Bank API
# ----------------------------
fert = fetch_world_bank_indicator(fertility_indicator).rename(
    columns={"value": "fertility"}
)

# Only codes that survived the filtering of the contraception file count as
# countries here.
valid_codes = set(contra["country_code"])

_keep_2000 = (
    (fert["year"] == 2000)
    & fert["fertility"].notna()
    & fert["country_code"].isin(valid_codes)
)
f2000 = fert.loc[_keep_2000].copy()

# ----------------------------
# 7. Lowest fertility in 2000
# ----------------------------
low_fert_2000 = f2000.loc[f2000["fertility"].idxmin()]

print(
    "=== LOWEST FERTILITY IN 2000 ===",
    f"Country: {low_fert_2000['country_name_api']}",
    f"Fertility rate: {low_fert_2000['fertility']:.4f}",
    "",
    sep="\n",
)
# ----------------------------
# 8. Merge 2000 data
# ----------------------------
# Inner join on the country code: only countries present in both sources
# remain in the sample.
_fertility_2000 = f2000[["country_code", "fertility"]]
df2000 = c2000.merge(_fertility_2000, how="inner", on="country_code")
df2000 = df2000.dropna(subset=["contraceptive_use", "fertility"]).copy()

print(
    "=== MERGED 2000 SAMPLE ===",
    f"N countries with both variables: {len(df2000)}",
    "",
    sep="\n",
)

# ----------------------------
# 9. Correlation
# ----------------------------
corr_2000 = df2000["fertility"].corr(df2000["contraceptive_use"])

print(
    "=== CORRELATION IN 2000 ===",
    f"Correlation(fertility, contraceptive_use): {corr_2000:.6f}",
    "",
    sep="\n",
)
# ----------------------------
# 10. Linear regression
# ----------------------------
# OLS of fertility on contraceptive use with an intercept.
y = df2000["fertility"]
X = sm.add_constant(df2000["contraceptive_use"])
model = sm.OLS(y, X).fit()

# Names of the fitted terms: intercept first, then the slope.
_terms = ["const", "contraceptive_use"]
beta_0, beta_1 = (model.params[t] for t in _terms)
r_squared = model.rsquared

print("=== REGRESSION RESULTS ===", model.summary(), "", sep="\n")

regression_table = pd.DataFrame({
    "Variable": ["Constant", "Contraceptive use (%)"],
    "Coefficient": [beta_0, beta_1],
    "Std. Error": [model.bse[t] for t in _terms],
    "t-stat": [model.tvalues[t] for t in _terms],
    "p-value": [model.pvalues[t] for t in _terms],
})

print(
    "=== COMPACT REGRESSION TABLE ===",
    regression_table.to_string(index=False),
    f"\nObservations: {int(model.nobs)}",
    f"R-squared: {r_squared:.6f}",
    "",
    sep="\n",
)

# ----------------------------
# 11. Effect of 1 percentage point increase
# ----------------------------
# In a linear model the slope is the change in fertility per one-unit change
# in contraceptive use.
print(
    "=== EFFECT OF +1 PERCENTAGE POINT CONTRACEPTIVE USE ===",
    f"Estimated effect on fertility: {beta_1:.6f} births per woman",
    "",
    sep="\n",
)
# ----------------------------
# 12. Elasticity at the sample mean
# ----------------------------
# Elasticity = slope * (mean of x / mean of y), evaluated at the sample means.
mean_x, mean_y = df2000["contraceptive_use"].mean(), df2000["fertility"].mean()
elasticity = beta_1 * (mean_x / mean_y)

print(
    "=== ELASTICITY AT THE SAMPLE MEAN ===",
    f"Mean contraceptive use: {mean_x:.6f}",
    f"Mean fertility: {mean_y:.6f}",
    f"Elasticity: {elasticity:.6f}",
    "",
    sep="\n",
)
# ----------------------------
# 13. Save merged data and regression table
# ----------------------------
# Each frame is written without its index, in the same order as before.
for _frame, _csv_name in (
    (df2000, "merged_fertility_contraception_2000.csv"),
    (regression_table, "regression_table_2000.csv"),
):
    _frame.to_csv(_csv_name, index=False)
# ----------------------------
# 14. Figure 1: Scatterplot + fitted line
# ----------------------------
_fig1, _ax1 = plt.subplots(figsize=(8, 6))
_ax1.scatter(df2000["contraceptive_use"], df2000["fertility"], alpha=0.8)

# Evaluate the fitted line on an even grid spanning the observed x range.
x_grid = np.linspace(
    df2000["contraceptive_use"].min(),
    df2000["contraceptive_use"].max(),
    200,
)
y_hat = beta_0 + beta_1 * x_grid
_ax1.plot(x_grid, y_hat, linewidth=2)

_ax1.set_xlabel("Contraceptive prevalence, any method (% of married women 15–49)")
_ax1.set_ylabel("Fertility rate, total (births per woman)")
_ax1.set_title("Fertility vs. contraceptive use across countries, 2000")
_fig1.tight_layout()
_fig1.savefig("figure1_scatter_regression_2000.png", dpi=300)
plt.show()
# ----------------------------
# 15. Figure 2: Residual plot
# ----------------------------
# Residuals against fitted values, with a horizontal reference line at zero.
fitted, residuals = model.fittedvalues, model.resid

_fig2, _ax2 = plt.subplots(figsize=(8, 6))
_ax2.scatter(fitted, residuals, alpha=0.8)
_ax2.axhline(0, linewidth=1)
_ax2.set_xlabel("Fitted fertility")
_ax2.set_ylabel("Residuals")
_ax2.set_title("Residual plot: fertility on contraceptive use, 2000")
_fig2.tight_layout()
_fig2.savefig("figure2_residuals_2000.png", dpi=300)
plt.show()
# ----------------------------
# 16. Paste-ready regression output
# ----------------------------
print("=== PASTE-READY REGRESSION TABLE ===")
# Plain-text summary of the fit: slope and intercept with their standard
# errors, followed by the sample size and R-squared.
_paste_ready_table = f"""
Dependent variable: Fertility rate, total (births per woman)
---------------------------------------------------------
Coef. Std. Err.
---------------------------------------------------------
Contraceptive use (%) {beta_1:>8.4f} {model.bse['contraceptive_use']:>10.4f}
Constant {beta_0:>8.4f} {model.bse['const']:>10.4f}
---------------------------------------------------------
Observations {int(model.nobs):>8}
R-squared {r_squared:>8.4f}
---------------------------------------------------------
"""
print(_paste_ready_table)
# [notebook UI placeholder] Run to view results