AI Cho Mọi Người

AI Cho Mọi Người

Áp dụng cho advertising và Boston dataset

 

 

 

Trong bài này, chúng ta sẽ áp dụng ba biến thể của gradient descent cho bài toán dự đoán doanh thu và tiên đoán giá nhà Boston. Các bạn chú ý cách xử lý dữ liệu cho hai bộ dữ liệu này.

 

1. Bài toán dự đoán doanh thu

Dữ liệu dự đoán doanh thu theo số tiền cho quảng cáo như sau

TVRadioNewspaperSales
230.137.869.222.1
44.539.345.110.4
17.245.969.312
151.541.358.516.5
180.810.858.417.9
8.748.9757.2
57.532.823.511.8
120.219.611.613.2
8.62.114.8
199.82.621.215.6
66.15.824.212.6
214.724417.4
23.835.165.99.2
97.57.67.213.7
204.132.94619
195.447.752.922.4
67.836.611412.5
281.439.655.824.4
69.220.518.311.3
147.323.919.114.6
218.427.753.418
237.45.123.517.5
13.215.949.65.6
228.316.926.220.5
62.312.618.39.7
262.93.519.517
142.929.312.615
240.116.722.920.9
248.827.122.918.9
70.61640.810.5
292.928.343.221.4
112.917.438.611.9
97.21.53013.2
265.6200.317.4
95.71.47.411.9
290.74.18.517.8
266.943.8525.4
74.749.445.714.7
43.126.735.110.1
22837.73221.5
202.522.331.616.6
17733.438.717.1
293.627.71.820.7
206.98.426.417.9
25.125.743.38.5
175.122.531.516.1
89.79.935.710.6
239.941.518.523.2
227.215.849.919.8
66.911.736.89.7
199.83.134.616.4
100.49.63.610.7
216.441.739.622.6
182.646.258.721.2
262.728.815.920.2
198.949.46023.7
7.328.141.45.5
136.219.216.613.2
210.849.637.723.8
210.729.59.318.4
53.5221.48.1
261.342.754.724.2
239.315.527.320.7
102.729.68.414
131.142.828.916
699.30.911.3
31.524.62.211
139.314.510.213.4
237.427.51118.9
216.843.927.222.3
199.130.638.718.3
109.814.331.712.4
26.83319.38.8
129.45.731.311
213.424.613.117
16.943.789.48.7
27.51.620.76.9
120.528.514.214.2
5.429.99.45.3
1167.723.111
76.426.722.311.8
239.84.136.917.3
75.320.332.511.3
68.444.535.613.6
213.54333.821.7
193.218.465.720.2
76.327.51612
110.740.663.216
88.325.573.412.9
109.847.851.416.7
134.34.99.314
28.61.5337.3
217.733.55919.4
250.936.572.322.2
107.41410.911.5
163.331.652.916.9
197.63.55.916.7
184.9212220.5
289.742.351.225.4
135.241.745.917.2
222.44.349.816.7
296.436.3100.923.8
280.210.121.419.8
187.917.217.919.7
238.234.35.320.7
137.946.45915
251129.77.2
90.40.323.212
13.10.425.65.3
255.426.95.519.8
225.88.256.518.4
241.73823.221.8
175.715.42.417.1
209.620.610.720.9
78.246.834.514.6
75.13552.712.6
139.214.325.612.2
76.40.814.89.4
125.736.979.215.9
19.41622.36.6
141.326.846.215.5
18.821.750.47
2242.415.616.6
123.134.612.415.2
229.532.374.219.7
87.211.825.910.6
7.838.950.66.6
80.209.211.9
220.3493.224.7
59.61243.19.7
0.739.68.71.6
265.22.94317.7
8.427.22.15.7
219.833.545.119.6
36.938.665.610.8
48.3478.511.6
25.6399.39.5
273.728.959.720.8
4325.920.59.6
184.943.91.720.7
73.41712.910.9
193.735.475.619.2
220.533.237.920.1
104.65.734.410.4
96.214.838.912.3
140.31.9910.3
240.17.38.718.2
243.24944.325.4
3840.311.910.9
44.725.820.610.1
280.713.93716.1
1218.448.711.6
197.623.314.216.6
171.339.737.716
187.821.19.520.6
4.111.65.73.2
93.943.550.515.3
149.81.324.310.1
11.736.945.27.3
131.718.434.612.9
172.518.130.716.4
85.735.849.313.3
188.418.125.619.9
163.536.87.418
117.214.75.411.9
234.53.484.816.9
17.937.621.68
206.85.219.417.2
215.423.657.617.1
284.310.66.420
5011.618.48.4
164.520.947.417.5
19.620.1177.6
168.47.112.816.7
222.43.413.116.5
276.948.941.827
248.430.220.320.2
170.27.835.216.7
276.72.323.716.8
165.61017.617.6
156.62.68.315.5
218.55.427.417.2
56.25.729.78.7
287.64371.826.2
253.821.33017.6
20545.119.622.6
139.52.126.610.3
191.128.718.217.3
28613.93.720.9
18.712.123.46.7
39.541.15.810.8
75.510.8611.9
17.24.131.65.9
166.8423.619.6
149.735.6617.3
38.23.713.87.6
94.24.98.114
1779.36.414.8
283.64266.225.5
232.18.68.718.4

Đọc và xử lý data. Các đặc trưng trong bộ dữ liệu có cùng đơn vị (số tiền) nên chúng ta có thể chuẩn hóa toàn bộ các đặc trưng cùng một lúc (không cần làm theo từng đặc trưng).

# dataset
import numpy as np
from numpy import genfromtxt
import matplotlib.pyplot as plt

# Load the advertising data: columns are TV, Radio, Newspaper, Sales.
data = genfromtxt('advertising.csv', delimiter=',', skip_header=1)

m = data.shape[0]                 # number of samples
X, y = data[:, :3], data[:, 3:]   # features (ad spend) and target (sales)

# All three features share the same unit (money), so a single global
# mean/max/min is used to scale the whole feature matrix at once.
spread = X.max() - X.min()
X = (X - X.mean()) / spread

# Prepend a bias column of ones.
X_b = np.c_[np.ones((m, 1)), X]

 

Huấn luyện theo SGD

def stochastic_gradient_descent(features=None, targets=None,
                                n_epochs=50, learning_rate=0.01):
    """Train a linear-regression model with stochastic gradient descent.

    Args:
        features: design matrix with a leading bias column; defaults to the
            module-level ``X_b`` (advertising data) for backward compatibility.
        targets: column vector of targets; defaults to the module-level ``y``.
        n_epochs: number of passes over the data.
        learning_rate: gradient-descent step size.

    Returns:
        (thetas_path, losses): every parameter vector visited and the
        per-sample loss of every update.
    """
    # Fall back to the module-level advertising data when no data is given.
    X_data = X_b if features is None else features
    y_data = y if targets is None else targets
    n_samples, n_features = X_data.shape

    # initialize the parameters (one weight per column, bias included)
    thetas = np.random.randn(n_features, 1)

    thetas_path = [thetas]
    losses = []

    for epoch in range(n_epochs):
        for _ in range(n_samples):
            # pick one random sample (slicing keeps a 2-D row)
            random_index = np.random.randint(n_samples)
            xi = X_data[random_index:random_index + 1]
            yi = y_data[random_index:random_index + 1]

            # prediction for this sample
            oi = xi.dot(thetas)

            # squared-error loss (the 1/2 factor cancels in the gradient)
            li = (oi - yi) * (oi - yi) / 2

            # dL/d(output)
            g_li = oi - yi

            # gradient w.r.t. the parameters
            gradients = xi.T.dot(g_li)

            # update the parameters
            thetas = thetas - learning_rate * gradients

            # logging
            thetas_path.append(thetas)
            losses.append(li[0][0])

    return thetas_path, losses

# Store the stochastic run under an SGD-specific name; the original
# `bgd_thetas` was misleading (BGD refers to batch gradient descent).
sgd_thetas, losses = stochastic_gradient_descent()

# plot the loss for the first 500 updates
x_axis = list(range(500))
plt.plot(x_axis, losses[:500], color="r")
plt.show()

 

Giá trị loss cho 500 lần cập nhật đầu tiên

 

Huấn luyện theo MBGD

def mini_batch_gradient_descent(features=None, targets=None, n_iterations=50,
                                minibatch_size=20, learning_rate=0.01):
    """Train a linear-regression model with mini-batch gradient descent.

    Args:
        features: design matrix with a leading bias column; defaults to the
            module-level ``X_b`` (advertising data) for backward compatibility.
        targets: column vector of targets; defaults to the module-level ``y``.
        n_iterations: number of passes (epochs) over the data.
        minibatch_size: nominal number of samples per mini-batch.
        learning_rate: gradient-descent step size.

    Returns:
        (thetas_path, losses): every parameter vector visited and the mean
        squared loss of every mini-batch.
    """
    X_data = X_b if features is None else features
    y_data = y if targets is None else targets
    n_samples, n_features = X_data.shape

    thetas = np.random.randn(n_features, 1)
    thetas_path = [thetas]
    losses = []

    for epoch in range(n_iterations):
        # reshuffle the data every epoch
        shuffled_indices = np.random.permutation(n_samples)
        X_shuffled = X_data[shuffled_indices]
        y_shuffled = y_data[shuffled_indices]

        for i in range(0, n_samples, minibatch_size):
            xi = X_shuffled[i:i + minibatch_size]
            yi = y_shuffled[i:i + minibatch_size]
            # The last slice can be shorter than minibatch_size when
            # n_samples is not a multiple of it; use the real size so the
            # gradient and the logged loss are proper means.
            batch_size = xi.shape[0]

            # predictions for this mini-batch
            output = xi.dot(thetas)

            # per-sample squared errors
            loss = (output - yi) ** 2

            # dL/d(output), averaged over the batch
            loss_grd = 2 * (output - yi) / batch_size

            # gradient w.r.t. the parameters
            gradients = xi.T.dot(loss_grd)

            # update the parameters (learning_rate hoisted out of the loop)
            thetas = thetas - learning_rate * gradients
            thetas_path.append(thetas)

            losses.append(np.sum(loss) / batch_size)

    return thetas_path, losses

mbgd_thetas, losses = mini_batch_gradient_descent()

# plot the loss for the first 200 mini-batches
x_axis = [*range(200)]
plt.plot(x_axis, losses[:200], color="r")
plt.show()

 

Kết quả loss cho 200 mini-batch đầu tiên

 

Huấn luyện theo BGD

def batch_gradient_descent(features=None, targets=None,
                           n_iterations=100, learning_rate=0.01):
    """Train a linear-regression model with full-batch gradient descent.

    Args:
        features: design matrix with a leading bias column; defaults to the
            module-level ``X_b`` (advertising data) for backward compatibility.
        targets: column vector of targets; defaults to the module-level ``y``.
        n_iterations: number of gradient steps (one per full pass).
        learning_rate: gradient-descent step size.

    Returns:
        (thetas_path, losses): every parameter vector visited and the mean
        squared loss after every step.
    """
    X_data = X_b if features is None else features
    y_data = y if targets is None else targets
    n_samples, n_features = X_data.shape

    # initialize the parameters
    thetas = np.random.randn(n_features, 1)
    thetas_path = [thetas]
    losses = []

    for _ in range(n_iterations):
        # predictions for the whole dataset
        output = X_data.dot(thetas)
        error = output - y_data

        # per-sample squared errors
        loss = error ** 2

        # gradient of the mean squared loss w.r.t. the parameters
        gradients = X_data.T.dot(2 * error / n_samples)

        # update the parameters
        thetas = thetas - learning_rate * gradients
        thetas_path.append(thetas)

        losses.append(np.sum(loss) / n_samples)

    return thetas_path, losses

bgd_thetas, losses = batch_gradient_descent()

# plot the loss for the first 100 iterations
x_axis = [*range(100)]
plt.plot(x_axis, losses[:100], color="r")
plt.show()

 

Giá trị loss cho từng epoch

 

 

2. Bài toán tiên đoán giá nhà Boston

Dữ liệu nhà Boston được hiển thị ở bảng sau

crimzninduschasnoxrmagedisradtaxptratioblacklstatmedv
0.00632182.3100.5386.57565.24.09129615.3396.94.9824
0.0273107.0700.4696.42178.94.9671224217.8396.99.1421.6
0.0323702.1800.4586.99845.86.0622322218.7394.632.9433.4
0.0690502.1800.4587.14754.26.0622322218.7396.95.3336.2
0.0882912.57.8700.5246.01266.65.5605531115.2395.612.4322.9
0.2248912.57.8700.5246.37794.36.3467531115.2392.5220.4515
0.1174712.57.8700.5246.00982.96.2267531115.2396.913.2718.9
0.0937812.57.8700.5245.889395.4509531115.2390.515.7121.7
0.6297608.1400.5385.94961.84.7075430721396.98.2620.4
0.6379608.1400.5386.09684.54.4619430721380.0210.2618.2
0.6273908.1400.5385.83456.54.4986430721395.628.4719.9
1.0539308.1400.5385.93529.34.4986430721386.856.5823.1
0.8027108.1400.5385.45636.63.7965430721288.9911.6920.2
1.2517908.1400.5385.5798.13.7979430721376.5721.0213.6
0.8520408.1400.5385.96589.24.0123430721392.5313.8319.6
1.2324708.1400.5386.14291.73.9769430721396.918.7215.2
0.9884308.1400.5385.8131004.0952430721394.5419.8814.5
0.9557708.1400.5386.04788.84.4534430721306.3817.2814.8
1.1308108.1400.5385.71394.14.233430721360.1722.612.7
1.3547208.1400.5386.0721004.175430721376.7313.0414.5
1.6128208.1400.5386.09696.93.7598430721248.3120.3413.5
0.1750505.9600.4995.96630.23.8473527919.2393.4310.1324.7
0.02763752.9500.4286.59521.85.4011325218.3395.634.3230.8
0.03359752.9500.4287.02415.85.4011325218.3395.621.9834.9
0.141506.9100.4486.1696.65.7209323317.9383.375.8125.3
0.1593606.9100.4486.2116.55.7209323317.9394.467.4424.7
0.1226906.9100.4486.069405.7209323317.9389.399.5521.2
0.1714206.9100.4485.68233.85.1004323317.9396.910.2119.3
0.1883606.9100.4485.78633.35.1004323317.9396.914.1520
0.2292706.9100.4486.0385.55.6894323317.9392.7418.816.6
0.2197706.9100.4485.602626.0877323317.9396.916.219.4
0.08873215.6400.4395.96345.76.8147424316.8395.5613.4519.7
0.04337215.6400.4396.115636.8147424316.8393.979.4320.5
0.04981215.6400.4395.99821.46.8147424316.8396.98.4323.4
0.013675400.415.88847.67.3197346921.1396.914.818.9
0.01311901.2200.4037.24921.98.6966522617.9395.934.8135.4
0.02055850.7400.416.38335.79.1876231317.3396.95.7724.7
0.014321001.3200.4116.81640.58.3248525615.1392.93.9531.6
0.15445255.1300.4536.14529.27.8148828419.7390.686.8623.3
0.14932255.1300.4535.74166.27.2254828419.7395.1113.1518.7
0.17171255.1300.4535.96693.46.8185828419.7378.0814.4416
0.1265255.1300.4536.76243.47.9809828419.7395.589.525
0.0195117.51.3800.41617.10459.59.2229321618.6393.248.0533
0.03584803.3700.3986.2917.86.6115433716.1396.94.6723.5
0.04379803.3700.3985.78731.16.6115433716.1396.910.2419.4
0.0578912.56.0700.4095.87821.46.498434518.9396.218.122
0.1355412.56.0700.4095.59436.86.498434518.9396.913.0917.4
0.08826010.8100.4136.4176.65.2873430519.2383.736.7224.2
0.09164010.8100.4136.0657.85.2873430519.2390.915.5222.8
0.19539010.8100.4136.2456.25.2873430519.2377.177.5423.4
0.07896012.8300.4376.27364.2515539818.7394.926.7824.1
0.09512012.8300.4376.286454.5026539818.7383.238.9421.4
0.10153012.8300.4376.27974.54.0522539818.7373.6611.9720
0.08707012.8300.4376.1445.84.0905539818.7386.9610.2720.8
0.04113254.8600.4266.72733.55.4007428119396.95.2928
0.04462254.8600.4266.61970.45.4007428119395.637.2223.9
0.03551254.8600.4266.16746.75.4007428119390.647.5122.9
0.0505904.4900.4496.389484.7794324718.5396.99.6223.9
0.0573504.4900.4496.6356.14.4377324718.5392.36.5326.6
0.0518804.4900.4496.01545.14.4272324718.5395.9912.8622.5
0.0715104.4900.4496.12156.83.7476324718.5395.158.4422.2
0.056603.4100.4897.00786.33.4217227017.8396.95.523.6
0.0530203.4100.4897.07963.13.4145227017.8396.065.728.7
0.0468403.4100.4896.41766.13.0923227017.8392.188.8122.6
0.028752815.0400.4646.21128.93.6659427018.2396.336.2125
0.042942815.0400.4646.24977.33.615427018.2396.910.5920.6
0.1150402.8900.4456.16369.63.4952227618391.8311.3421.4
0.1486608.5600.526.72779.92.7778538420.9394.769.4227.5
0.1143208.5600.526.78171.32.8561538420.9395.587.6726.5
0.2287608.5600.526.40585.42.7147538420.970.810.6318.6
0.2116108.5600.526.13787.42.7147538420.9394.4713.4419.3
0.171208.5600.525.83691.92.211538420.9395.6718.6619.5
0.1311708.5600.526.12785.22.1224538420.9387.6914.0920.4
0.1280208.5600.526.47497.12.4329538420.9395.2412.2719.8
0.2636308.5600.526.22991.22.5451538420.9391.2315.5519.4
0.10084010.0100.5476.71581.62.6775643217.8395.5910.1622.8
0.14231010.0100.5476.25484.22.2565643217.8388.7410.4518.5
0.13158010.0100.5476.17672.52.7301643217.8393.312.0421.2
0.15098010.0100.5476.02182.62.7474643217.8394.5110.319.2
0.13058010.0100.5475.87273.12.4775643217.8338.6315.3720.4
0.14476010.0100.5475.73165.22.7592643217.8391.513.6119.3
0.06899025.6500.5815.8769.72.2577218819.1389.1514.3722
0.07165025.6500.5816.00484.12.1974218819.1377.6714.2720.3
0.09299025.6500.5815.96192.92.0869218819.1378.0917.9320.5
0.15038025.6500.5815.856971.9444218819.1370.3125.4117.3
0.09849025.6500.5815.87995.82.0063218819.1379.3817.5818.8
0.38735025.6500.5815.61395.61.7572218819.1359.2927.2615.7
0.25915021.8900.6245.693961.7883443721.2392.1117.1916.2
0.32543021.8900.6246.43198.81.8125443721.2396.915.3918
1.19294021.8900.6246.32697.72.271443721.2396.912.2619.6
0.32982021.8900.6245.82295.42.4699443721.2388.6915.0318.4
0.97617021.8900.6245.75798.42.346443721.2262.7617.3115.6
0.32264021.8900.6245.94293.51.9669443721.2378.2516.917.4
0.35233021.8900.6246.45498.41.8498443721.2394.0814.5917.1
0.2498021.8900.6245.85798.21.6686443721.2392.0421.3213.3
0.54452021.8900.6246.15197.91.6687443721.2396.918.4617.8
1.62864021.8900.6245.0191001.4394443721.2396.934.4114.4
3.32105019.5810.8715.4031001.3216540314.7396.926.8213.4
2.37934019.5800.8716.131001.4191540314.7172.9127.813.8
2.36862019.5800.8714.92695.71.4608540314.7391.7129.5314.6
2.33099019.5800.8715.18693.81.5296540314.7356.9928.3217.8
2.73397019.5800.8715.59794.91.5257540314.7351.8521.4515.4
1.6566019.5800.8716.12297.31.618540314.7372.814.121.5
2.14918019.5800.8715.70998.51.6232540314.7261.9515.7919.4
1.41385019.5810.8716.129961.7494540314.7321.0215.1217
2.44668019.5800.8715.272941.7364540314.788.6316.1413.1
1.34284019.5800.6056.0661001.7573540314.7353.896.4324.3
1.42502019.5800.8716.511001.7659540314.7364.317.3923.3
1.27346019.5810.6056.2592.61.7984540314.7338.925.527
1.46336019.5800.6057.48990.81.9709540314.7374.431.7350
1.51902019.5810.6058.37593.92.162540314.7388.453.3250
2.24236019.5800.6055.85491.82.422540314.7395.1111.6422.7
2.924019.5800.6056.101932.2834540314.7240.169.8125
2.01019019.5800.6057.92996.22.0459540314.7369.33.750
1.80028019.5800.6055.87779.22.4259540314.7227.6112.1423.8
2.44953019.5800.6056.40295.22.2625540314.7330.0411.3222.3
1.20742019.5800.6055.87594.62.4259540314.7292.2914.4317.4
2.3139019.5800.6055.8897.32.3887540314.7348.1312.0319.1
0.1391404.0500.515.57288.52.5961529616.6396.914.6923.1
0.0917804.0500.516.41684.12.6463529616.6395.59.0423.6
0.0844704.0500.515.85968.72.7019529616.6393.239.6422.6
0.0666404.0500.516.54633.13.1323529616.6390.965.3329.4
0.0702204.0500.516.0247.23.5549529616.6393.2310.1123.2
0.0542504.0500.516.31573.43.3175529616.6395.66.2924.6
0.0664204.0500.516.8674.42.9153529616.6391.276.9229.9
0.057802.4600.4886.9858.42.829319317.8396.95.0437.2
0.0658802.4600.4887.76583.32.741319317.8395.567.5639.8
0.0688802.4600.4886.14462.22.5979319317.8396.99.4536.2
0.0910302.4600.4887.15592.22.7006319317.8394.124.8237.9
0.1000802.4600.4886.56395.62.847319317.8396.95.6832.5
0.0560202.4600.4887.83153.63.1992319317.8392.634.4550
0.07875453.4400.4376.78241.13.7886539815.2393.876.6832
0.0837453.4400.4377.18538.94.5667539815.2396.95.3934.9
0.09068453.4400.4376.95121.56.4798539815.2377.685.137
0.06911453.4400.4376.73930.86.4798539815.2389.714.6930.5
0.08664453.4400.4377.17826.36.4798539815.2390.492.8736.4
0.02187602.9300.4016.89.96.2196126515.6393.375.0331.1
0.01439602.9300.4016.60418.86.2196126515.6376.74.3829.1
0.04666801.5200.4047.10736.67.309232912.6354.318.6130.3
0.01778951.4700.4037.13513.97.6534340217384.34.4532.9
0.0344582.52.0300.4156.16238.46.27234814.7393.777.4324.1
0.0351952.6800.41617.85333.25.118422414.7392.783.8148.5
0.02009952.6800.41618.03431.95.118422414.7390.552.8850
0.13642010.5900.4895.89122.33.9454427718.6396.910.8722.6
0.22969010.5900.4896.32652.54.3549427718.6394.8710.9724.4
0.13587010.5910.4896.06459.14.2392427718.6381.3214.6624.4
0.37578010.5910.4895.40488.63.665427718.6395.2423.9819.3
0.14052010.5900.4896.37532.33.9454427718.6385.819.3828.1
0.28955010.5900.4895.4129.83.5875427718.6348.9329.5523.7
0.0456013.8910.555.888563.1121527616.4392.813.5123.3
0.4077106.210.5076.16491.33.048830717.4395.2421.4621.7
0.6235606.210.5076.87977.73.2721830717.4390.399.9327.5
0.614706.200.5076.61880.83.2721830717.4396.97.630.1
0.3153306.200.5048.26678.32.8944830717.4385.054.1444.8
0.5269306.200.5048.725832.8944830717.43824.6350
0.3821406.200.5048.0486.53.2157830717.4387.383.1337.6
0.4123806.200.5047.16379.93.2157830717.4372.086.3631.6
0.4417806.200.5046.55221.43.3751830717.4380.343.7631.5
0.53706.200.5045.98168.13.6715830717.4378.3511.6524.3
0.5752906.200.5078.33773.33.8384830717.4385.912.4741.7
0.3314706.200.5078.24770.43.6519830717.4378.953.9548.3
0.4479106.210.5076.72666.53.6519830717.4360.28.0529
0.3304506.200.5076.08661.53.6519830717.4376.7510.8824
0.5205806.210.5076.63176.54.148830717.4388.459.5425.1
0.11329304.9300.4286.89754.36.3361630016.6391.2511.3822
0.1029304.9300.4286.35852.97.0355630016.6372.7511.2222.2
0.12757304.9300.4286.3937.87.0355630016.6374.715.1923.7
0.20608225.8600.4315.59376.57.9549733019.1372.4912.517.6
0.33983225.8600.4316.10834.98.0555733019.1390.189.1624.3
0.16439225.8600.4316.43349.17.8265733019.1374.719.5224.5
0.19073225.8600.4316.71817.57.8265733019.1393.746.5626.2
0.1403225.8600.4316.487137.3967733019.1396.285.924.4
0.21409225.8600.4316.4388.97.3967733019.1377.073.5924.8
0.36894225.8600.4318.2598.48.9067733019.1396.93.5442.8
0.54011203.9700.6477.20381.82.1121526413392.89.5933.8
0.53412203.9700.6477.5289.42.1398526413388.377.2643.1
0.52014203.9700.6478.39891.52.2885526413386.865.9148.8
0.82526203.9700.6477.32794.52.0788526413393.4211.2531
0.55007203.9700.6477.20691.61.9301526413387.898.136.5
0.76162203.9700.6475.5662.81.9865526413392.410.4522.8
0.7857203.9700.6477.01484.62.1329526413384.0714.7930.7
0.5405203.9700.5757.4752.62.872526413390.33.1643.5
0.16211206.9600.4646.2416.34.429322318.6396.96.5925.2
0.1146206.9600.4646.53858.73.9175322318.6394.967.7324.4
0.22188206.9610.4647.69151.84.3665322318.6390.776.5835.2
0.05644406.4110.4476.75832.94.0776425417.6396.93.5332.4
0.21038203.3300.44296.81232.24.1007521614.9396.94.8535.1
0.03705203.3300.44296.96837.25.2447521614.9392.234.5935.4
0.06129203.3310.44297.64549.75.2119521614.9377.073.0146
0.01501901.2110.4017.92324.85.885119813.6395.523.1650
0.00906902.9700.47.08820.87.3073128515.3394.727.8532.2
0.01096552.2500.3896.45331.97.3073130015.3394.728.2322
0.01965801.7600.3856.2331.59.0892124118.2341.612.9320.1
0.045952.55.3200.4056.31545.67.3172629316.6396.97.622.3
0.03502804.9500.4116.86127.95.1167424519.2396.93.3328.5
0.03615804.9500.4116.6323.45.1167424519.2396.94.727.9
0.08265013.9200.4376.12718.45.5027428916396.98.5823.9
0.05372013.9200.4376.549515.9604428916392.857.3927.1
0.14103013.9200.4375.79586.32428916396.915.8420.3
0.03537346.0900.4336.5940.45.4917732916.1395.759.522
0.09266346.0900.4336.49518.45.4917732916.1383.618.6726.4
0.1346.0900.4336.98217.75.4917732916.1390.434.8633.1
0.05515332.1800.4727.23641.14.022722218.4393.686.9336.1
0.05479332.1800.4726.61658.13.37722218.4393.368.9328.4
0.07503332.1800.4727.4271.93.0992722218.4396.96.4733.4
0.4929809.900.5446.63582.53.3175430418.4396.94.5422.8
0.349409.900.5445.97276.73.1025430418.4396.249.9720.3
2.6354809.900.5444.97337.82.5194430418.4350.4512.6416.1
0.7904109.900.5446.12252.82.6403430418.4396.95.9822.1
0.2616909.900.5446.02390.42.834430418.4396.311.7219.4
0.2535609.900.5445.70577.73.945430418.4396.4211.516.2
0.3182709.900.5445.91483.23.9986430418.4390.718.3317.8
0.2452209.900.5445.78271.74.0317430418.4396.915.9419.8
0.4020209.900.5446.38267.23.5325430418.4395.2110.3623.1
0.167607.3800.4936.42652.34.5404528719.6396.97.223.8
0.3410907.3800.4936.41540.14.7211528719.6396.96.1225
0.1918607.3800.4936.43114.75.4159528719.6393.685.0824.6
0.2410307.3800.4936.08343.75.4159528719.6396.912.7922.2
0.0661703.2400.465.86825.85.2146443016.9382.449.9719.3
0.0454403.2400.466.14432.25.8736443016.9368.579.0919.8
0.0508305.1900.5156.31638.16.4584522420.2389.715.6822.2
0.0373805.1900.5156.3138.56.4584522420.2389.46.7520.7
0.0342705.1900.5155.86946.35.2311522420.2396.99.819.5
0.0330605.1900.5156.05937.34.8122522420.2396.148.5120.6
0.0549705.1900.5155.98545.44.8122522420.2396.99.7419
0.0615105.1900.5155.96858.54.8122522420.2396.99.2918.7
0.01301351.5200.4427.24149.37.0379128415.5394.745.4932.7
0.0249801.8900.5186.5459.76.2669142215.9389.968.6516.5
0.02543553.7800.4846.69656.45.7321537017.6396.97.1823.9
0.03049553.7800.4846.87428.16.4654537017.6387.974.6131.2
0.0187854.1500.4296.51627.78.5353435117.9392.436.3623.1
0.01501802.0100.4356.63529.78.344428017390.945.9924.5
0.02899401.2500.4296.93934.58.7921133519.7389.855.8926.6
0.07244601.6900.4115.88418.510.7103441118.3392.337.7918.6
8.98296018.110.776.21297.42.12222466620.2377.7317.617.8
3.8497018.110.776.395912.50522466620.2391.3413.2721.7
5.20177018.110.776.12783.42.72272466620.2395.4311.4822.7
4.26131018.100.776.11281.32.50912466620.2390.7412.6722.6
4.54192018.100.776.398882.51822466620.2374.567.7925
3.67822018.100.775.36296.22.10362466620.2380.7910.1920.8
4.55587018.100.7183.56187.91.61322466620.2354.77.1227.5
3.69695018.100.7184.96391.41.75232466620.2316.031421.9
13.5222018.100.6313.8631001.51062466620.2131.4213.3323.1
4.89822018.100.6314.971001.33252466620.2375.523.2650
6.53876018.110.6317.01697.51.20242466620.2392.052.9650
9.2323018.100.6316.2161001.16912466620.2366.159.5350
8.26725018.110.6685.87589.61.12962466620.2347.888.8850
11.1081018.100.6684.9061001.17422466620.2396.934.7713.8
18.4982018.100.6684.1381001.1372466620.2396.937.9713.8
15.288018.100.6716.64993.31.34492466620.2363.0223.2413.9
9.82349018.100.6716.79498.81.3582466620.2396.921.2413.3
9.18702018.100.75.5361001.58042466620.2396.923.611.3
7.99248018.100.75.521001.53312466620.2396.924.5612.3
20.0849018.100.74.36891.21.43952466620.2285.8330.638.8
24.3938018.100.74.6521001.46722466620.2396.928.2810.5
22.5971018.100.7589.51.51842466620.2396.931.997.4
8.15174018.100.75.3998.91.72812466620.2396.920.8511.5
5.29305018.100.76.05182.52.16782466620.2378.3818.7623.2
11.5779018.100.75.036971.772466620.2396.925.689.7
13.3598018.100.6935.88794.71.78212466620.2396.916.3512.7
5.87205018.100.6936.405961.67682466620.2396.919.3712.5
38.3518018.100.6935.4531001.48962466620.2396.930.595
25.0461018.100.6935.9871001.58882466620.2396.926.775.6
14.2362018.100.6936.3431001.57412466620.2396.920.327.2
24.8017018.100.6935.349961.70282466620.2396.919.778.3
11.9511018.100.6595.6081001.28522466620.2332.0912.1327.9
7.40389018.100.5975.61797.91.45472466620.2314.6426.417.2
28.6558018.100.5975.1551001.58942466620.2210.9720.0816.3
45.7461018.100.6934.5191001.65822466620.288.2736.987
18.0846018.100.6796.4341001.83472466620.227.2529.057.2
25.9406018.100.6795.30489.11.64752466620.2127.3626.6410.4
73.5341018.100.6795.9571001.80262466620.216.4520.628.8
11.8123018.100.7186.82476.51.7942466620.248.4522.748.4
8.79212018.100.5845.56570.62.06352466620.23.6517.1611.7
15.8603018.100.6795.89695.41.90962466620.27.6824.398.3
37.6619018.100.6796.20278.71.86292466620.218.8214.5210.9
7.36711018.100.6796.19378.11.93562466620.296.7321.5211
9.33889018.100.6796.3895.61.96822466620.260.7224.089.5
10.0623018.100.5846.83394.32.08822466620.281.3319.6914.1
6.44405018.100.5846.42574.82.20042466620.297.9512.0316.1
5.58107018.100.7136.43687.92.31582466620.2100.1916.2214.3
13.9134018.100.7136.208952.22222466620.2100.6315.1711.7
15.1772018.100.746.1521001.91422466620.29.3226.458.7
9.39063018.100.745.62793.91.81722466620.2396.922.8812.8
22.0511018.100.745.81892.41.86622466620.2391.4522.1110.5
9.72418018.100.746.40697.22.06512466620.2385.9619.5217.1
5.66637018.100.746.2191002.00482466620.2395.6916.5918.4
9.96654018.100.746.4851001.97842466620.2386.7318.8515.4
12.8023018.100.745.85496.61.89562466620.2240.5223.7910.8
10.6718018.100.746.45994.81.98792466620.243.0623.9811.8
9.92485018.100.746.25196.62.1982466620.2388.5216.4412.6
9.32909018.100.7136.18598.72.26162466620.2396.918.1314.1
5.44114018.100.7136.65598.22.35522466620.2355.2917.7315.2
5.09017018.100.7136.29791.82.36822466620.2385.0917.2716.1
8.24809018.100.7137.39399.32.45272466620.2375.8716.7417.8
4.75237018.100.7136.52586.52.43582466620.250.9218.1314.1
8.20058018.100.7135.93680.32.77922466620.23.516.9413.5
7.75223018.100.7136.30183.72.78312466620.2272.2116.2314.9
6.80117018.100.7136.08184.42.71752466620.2396.914.720
4.81213018.100.7136.701902.59752466620.2255.2316.4216.4
3.69311018.100.7136.37688.42.56712466620.2391.4314.6517.7
6.65492018.100.7136.317832.73442466620.2396.913.9919.5
5.82115018.100.7136.51389.92.80162466620.2393.8210.2920.2
7.83932018.100.6556.20965.42.96342466620.2396.913.2221.4
3.1636018.100.6555.75948.23.06652466620.2334.414.1319.9
3.77498018.100.6555.95284.72.87152466620.222.0117.1519
4.42228018.100.5846.00394.52.54032466620.2331.2921.3219.1
15.5757018.100.585.926712.90842466620.2368.7418.1319.1
13.0751018.100.585.71356.72.82372466620.2396.914.7620.1
4.03841018.100.5326.22990.73.09932466620.2395.3312.8719.6
3.56868018.100.586.437752.89652466620.2393.3714.3623.2
8.05579018.100.5845.42795.42.42982466620.2352.5818.1413.8
4.87141018.100.6146.48493.62.30532466620.2396.2118.6816.7
15.0234018.100.6145.30497.32.10072466620.2349.4824.9112
10.233018.100.6146.18596.72.17052466620.2379.718.0314.6
14.3337018.100.6146.229881.95122466620.2383.3213.1121.4
5.82401018.100.5326.24264.73.42422466620.2396.910.7423
5.70818018.100.5326.7574.93.33172466620.2393.077.7423.7
2.81838018.100.5325.76240.34.09832466620.2392.9210.4221.8
2.37857018.100.5835.87141.93.7242466620.2370.7313.3420.6
5.69175018.100.5836.11479.83.54592466620.2392.6814.9819.1
4.83567018.100.5835.90553.23.15232466620.2388.2211.4520.6
0.15086027.7400.6095.45492.71.8209471120.1395.0918.0615.2
0.20746027.7400.6095.093981.8226471120.1318.4329.688.1
0.10574027.7400.6095.98398.81.8681471120.1390.1118.0713.6
0.11132027.7400.6095.98383.52.1099471120.1396.913.3520.1
0.1733109.6900.5855.707542.3817639119.2396.912.0121.8
0.2683809.6900.5855.79470.62.8927639119.2396.914.118.3
0.1778309.6900.5855.56973.52.3999639119.2395.7715.117.5
0.06263011.9300.5736.59369.12.4786127321391.999.6722.4
0.04527011.9300.5736.1276.72.2875127321396.99.0820.6
0.06076011.9300.5736.976912.1675127321396.95.6423.9
0.04741011.9300.5736.0380.82.505127321396.97.8811.9

 

Đọc và xử lý data. Ở bộ dữ liệu này, chúng ta cần chuẩn hóa cho từng đặc trưng (từng cột dữ liệu đặc trưng)

# data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


data = pd.read_csv('BostonHousing.csv')

def normal(x: list):
    """Normalize one feature column: (x - mean) / (max - min).

    Called via DataFrame.apply, so x is one column (a pandas Series)
    at a time — the features have different scales, hence per-column
    normalization.
    """
    maxi = max(x)
    mini = min(x)
    avg = np.mean(x)
    new = [(i-avg)/(maxi-mini) for i in x ]
    
    return new

df = data.copy()
df = df.apply(normal, axis=0)

Xd = df.drop(columns=['medv'])
Xd.insert(0, 'X0', 1) # bias 

# numpy array format
y = df.medv.values
y = np.expand_dims(y, axis=1)

X_b = Xd.values

# sample size and parameter count
m = len(df.index)
# BUG FIX: the original read `X.shape[1]`, but `X` is never defined in this
# script (NameError when run standalone); the design matrix here is X_b
# (bias + 13 features = 14 columns).
n = X_b.shape[1]
theta = np.ones(n)

 

Huấn luyện theo SGD

def stochastic_gradient_descent(features=None, targets=None,
                                n_epochs=50, learning_rate=0.01):
    """Train a linear-regression model with stochastic gradient descent.

    Args:
        features: design matrix with a leading bias column; defaults to the
            module-level ``X_b`` (Boston housing data) for backward
            compatibility.
        targets: column vector of targets; defaults to the module-level ``y``.
        n_epochs: number of passes over the data.
        learning_rate: gradient-descent step size.

    Returns:
        (thetas_path, losses): every parameter vector visited and the
        per-sample loss of every update.
    """
    # Fall back to the module-level Boston data when no data is given.
    X_data = X_b if features is None else features
    y_data = y if targets is None else targets
    n_samples, n_features = X_data.shape

    # initialize the parameters (replaces the hard-coded randn(14, 1))
    thetas = np.random.randn(n_features, 1)

    thetas_path = [thetas]
    losses = []

    for epoch in range(n_epochs):
        for _ in range(n_samples):
            # pick one random sample (slicing keeps a 2-D row)
            random_index = np.random.randint(n_samples)
            xi = X_data[random_index:random_index + 1]
            yi = y_data[random_index:random_index + 1]

            # prediction for this sample
            oi = xi.dot(thetas)

            # squared-error loss (the 1/2 factor cancels in the gradient)
            li = (oi - yi) * (oi - yi) / 2

            # dL/d(output)
            g_li = oi - yi

            # gradient w.r.t. the parameters
            gradients = xi.T.dot(g_li)

            # update the parameters
            thetas = thetas - learning_rate * gradients

            # logging
            thetas_path.append(thetas)
            losses.append(li[0][0])

    return thetas_path, losses

# Store the stochastic run under an SGD-specific name; the original
# `bgd_thetas` was misleading (BGD refers to batch gradient descent).
sgd_thetas, losses = stochastic_gradient_descent()

# plot the loss for the first 100 updates
x_axis = list(range(100))
plt.plot(x_axis, losses[:100], color="r")
plt.show()

 

Giá trị loss cho 100 lần cập nhật đầu tiên

 

Huấn luyện theo MBGD

def mini_batch_gradient_descent(features=None, targets=None, n_iterations=200,
                                minibatch_size=64, learning_rate=0.01):
    """Train a linear-regression model with mini-batch gradient descent.

    Args:
        features: design matrix with a leading bias column; defaults to the
            module-level ``X_b`` (Boston housing data) for backward
            compatibility.
        targets: column vector of targets; defaults to the module-level ``y``.
        n_iterations: number of passes (epochs) over the data.
        minibatch_size: nominal number of samples per mini-batch.
        learning_rate: gradient-descent step size.

    Returns:
        (thetas_path, losses): every parameter vector visited and the mean
        squared loss of every mini-batch.
    """
    X_data = X_b if features is None else features
    y_data = y if targets is None else targets
    n_samples, n_features = X_data.shape

    thetas = np.random.randn(n_features, 1)
    thetas_path = [thetas]
    losses = []

    for epoch in range(n_iterations):
        # reshuffle the data every epoch
        shuffled_indices = np.random.permutation(n_samples)
        X_shuffled = X_data[shuffled_indices]
        y_shuffled = y_data[shuffled_indices]

        for i in range(0, n_samples, minibatch_size):
            xi = X_shuffled[i:i + minibatch_size]
            yi = y_shuffled[i:i + minibatch_size]
            # BUG FIX: with 506 Boston samples and minibatch_size=64 the
            # last batch holds only 58 samples, but the original divided by
            # 64 anyway, under-scaling its gradient and logged loss. Use the
            # actual slice length.
            batch_size = xi.shape[0]

            # predictions for this mini-batch
            output = xi.dot(thetas)

            # per-sample squared errors
            loss = (output - yi) ** 2

            # dL/d(output), averaged over the batch
            loss_grd = 2 * (output - yi) / batch_size

            # gradient w.r.t. the parameters
            gradients = xi.T.dot(loss_grd)

            # update the parameters (learning_rate hoisted out of the loop)
            thetas = thetas - learning_rate * gradients
            thetas_path.append(thetas)

            losses.append(np.sum(loss) / batch_size)

    return thetas_path, losses

mbgd_thetas, losses = mini_batch_gradient_descent()

# plot the loss for the first 100 mini-batches
x_axis = [*range(100)]
plt.plot(x_axis, losses[:100], color="r")
plt.show()

 

Giá trị loss cho 100 mini-batch đầu tiên

 

Huấn luyện theo BGD

def batch_gradient_descent():
    """Train linear regression on the module-level Boston data (``X_b``,
    ``y``) with full-batch gradient descent.

    Returns:
        (thetas_path, losses): every parameter vector visited and the mean
        squared loss after every step.
    """
    num_steps = 500
    step_size = 0.01

    # initialize the parameters (bias + 13 features)
    thetas = np.random.randn(14, 1)
    history = [thetas]
    loss_log = []

    for _ in range(num_steps):
        # predictions and residuals over the whole dataset
        preds = X_b.dot(thetas)
        residuals = preds - y

        # per-sample squared errors
        squared = residuals ** 2

        # gradient of the mean squared loss w.r.t. the parameters
        grads = X_b.T.dot(2 * residuals / m)

        # take one gradient step
        thetas = thetas - step_size * grads
        history.append(thetas)

        loss_log.append(np.sum(squared) / m)

    return history, loss_log

bgd_thetas, losses = batch_gradient_descent()

# plot the loss for the first 500 iterations (the original comment said
# 100, but the code plots 500)
x_axis = list(range(500))
plt.plot(x_axis,losses[:500], color="r")
plt.show()

 

Giá trị loss cho 500 iteration đầu tiên