Hi guys,
I've run into trouble with the unit test for the gradient descent method in the 4th cell of the 2nd notebook. A bit of context: the unit test there defines X, y, theta, the learning rate, and the number of iterations:
self.X = np.array([[1, 2, 4, 5, 8], [1, 3, 6, 8, 9], [1, 4, 5, 4, 2], [1, 5, 1, 6, 8]])
self.y = np.array([7, 6, 5, 4])
self.t = np.array([0.1, 0.2, 0.1, 0.3, 0.8])
self.lr = 0.001
self.iter = 1000
The unit test then runs the gradient descent method, takes the final theta, and asserts that it is close to the expected theta:
def test_gradient_descent(self):
    # Run gradient descent to find optimized theta values
    theta, c_history = gradient_descent(self.X, self.y, self.t, self.lr, self.iter)
    t = np.round(theta, 3)
    e_t = np.array([0.241, 0.221, 0.688, -0.19, 0.4])
    np.testing.assert_array_almost_equal(t, e_t, decimal=8)
    # Check that the cost history is decreasing
    for i in range(1, len(c_history)):
        self.assertLessEqual(c_history[i], c_history[i - 1])
This test fails for me, so I took the data and tried to figure out why. In the process, I noticed that the cost (loss) function value is higher for the expected theta than for the calculated one:
import numpy as np

theta_0 = np.array([0.1, 0.2, 0.1, 0.3, 0.8])
X = np.array([[1, 2, 4, 5, 8], [1, 3, 6, 8, 9], [1, 4, 5, 4, 2], [1, 5, 1, 6, 8]])
y = np.array([7, 6, 5, 4])
lr = 0.001
n_iter = 1000
expected_theta = np.array([0.241, 0.221, 0.688, -0.19, 0.4])
calculated_theta, _ = gradient_descent(X, y, theta_0, lr, n_iter)
print(f"Expected weights {expected_theta}: cost function {compute_cost(X, y, expected_theta)}")
print(f"Final weights {calculated_theta}: cost function {compute_cost(X, y, calculated_theta)}")
The output of this code is:
Expected weights [ 0.241 0.221 0.688 -0.19 0.4 ]: cost function 0.7826574999999999
Final weights [ 0.35525484 0.31975621 0.78449497 -0.48302338 0.52775727]: cost function 0.5133294839263606
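For anyone who wants to reproduce this without the notebook, here is a minimal sketch of what I assume compute_cost and gradient_descent look like (plain batch gradient descent on the cost J(theta) = (1/m) * sum((X @ theta - y)**2); the names and signatures match the notebook, but the bodies are my guess). With these definitions, the cost at the expected theta comes out to the same 0.7826575 printed above:

```python
import numpy as np

def compute_cost(X, y, theta):
    # Mean squared error: J(theta) = (1/m) * sum((X @ theta - y)^2)
    m = len(y)
    errors = X @ theta - y
    return np.sum(errors ** 2) / m

def gradient_descent(X, y, theta, lr, n_iter):
    # Batch gradient descent; the gradient of the cost above is
    # (2/m) * X^T (X @ theta - y)
    m = len(y)
    theta = theta.astype(float)
    cost_history = []
    for _ in range(n_iter):
        gradient = (2 / m) * (X.T @ (X @ theta - y))
        theta -= lr * gradient
        cost_history.append(compute_cost(X, y, theta))
    return theta, cost_history

X = np.array([[1, 2, 4, 5, 8], [1, 3, 6, 8, 9], [1, 4, 5, 4, 2], [1, 5, 1, 6, 8]])
y = np.array([7, 6, 5, 4])
expected_theta = np.array([0.241, 0.221, 0.688, -0.19, 0.4])

print(compute_cost(X, y, expected_theta))  # ≈ 0.7826575, matching the cost printed above
```

With lr = 0.001 the step size is well below the stability limit for this X, so the cost history decreases at every iteration, and the final theta ends up with a lower cost than the expected one, just like in my output above.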
Then I printed the weights at each iteration and found values very close to the expected ones, but at iteration 499 rather than at the end:
Iteration 499: theta = [ 0.24116448 0.22113463 0.68758194 -0.19034457 0.40039748]
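Since the expected theta shows up at iteration 499 out of 1000, my guess is that the reference values were generated with half the effective step size, e.g. a gradient of (1/m) * X.T @ (X @ theta - y) instead of the (2/m) factor that the (1/m)-scaled cost actually calls for. Those two choices trace out exactly the same trajectory once the learning rate compensates, as this sketch illustrates (descend and grad_scale are my own hypothetical names, not the notebook's code):

```python
import numpy as np

def descend(X, y, theta, lr, n_iter, grad_scale):
    # Batch gradient descent; grad_scale toggles between the correct (2/m)
    # gradient of the (1/m)*sum(e^2) cost and a (1/m) variant.
    m = len(y)
    theta = theta.astype(float)
    for _ in range(n_iter):
        theta -= lr * (grad_scale / m) * (X.T @ (X @ theta - y))
    return theta

X = np.array([[1, 2, 4, 5, 8], [1, 3, 6, 8, 9], [1, 4, 5, 4, 2], [1, 5, 1, 6, 8]])
y = np.array([7, 6, 5, 4])
theta_0 = np.array([0.1, 0.2, 0.1, 0.3, 0.8])

# A (1/m) gradient at lr takes exactly the same steps as a (2/m) gradient
# at lr/2, so a reference answer produced with the (1/m) form and 1000
# iterations lands where the (2/m) form sits after about 500 iterations.
a = descend(X, y, theta_0, 0.001, 1000, grad_scale=1)
b = descend(X, y, theta_0, 0.0005, 1000, grad_scale=2)
print(np.allclose(a, b))  # True
```

If that is what happened, either the expected values in the test need to be regenerated with the (2/m) gradient, or the reference implementation's gradient is missing the factor of 2.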
Could you please help me with this issue?