广场遇到溢出
Overflow encountered in square
我尝试搜索我遇到的溢出错误,但没有成功。
当我 运行 这个程序时,我得到 运行 对我来说毫无意义的时间错误。
这是我使用的数据:https://pastebin.com/MLWvUarm
import numpy as np
def loadData():
data = np.loadtxt('data.txt', delimiter=',')
x = np.c_[data[:,0:2]]
y = np.c_[data[:,-1]]
return x, y
def hypothesis(x, theta):
h = x.dot(theta)
return h
def computeCost(x, y, theta):
m = np.size(y, 0)
h = hypothesis(x, theta)
J = (1/(2*m)) * np.sum(np.square(h-y))
return J
def gradient_descent(x, y, theta, alpha, mxIT):
m = np.size(y, 0)
J_history = np.zeros((mxIT, 1))
for it in range(mxIT):
hyp = hypothesis(x, theta)
err = hyp - y
theta = theta - (alpha/m) * (x.T.dot(err))
J_history[it] = computeCost(x, y, theta)
return theta, J_history
def main():
x, y = loadData()
x = np.c_[np.ones(x.shape[0]), x]
theta = np.zeros((np.size(x, 1), 1))
alpha = 0.01
mxIT = 400
theta, j_his = gradient_descent(x, y, theta, alpha, mxIT)
print(theta)
if __name__ == "__main__":
main()
如何解决这个问题?
加载x
后,尝试除以均值,看是否收敛。 Link 均值文档:numpy.mean
...
x, y = loadData()
x = x / x.mean(axis=0, keepdims=True)
x = np.c_[np.ones(x.shape[0]), x]
...
目前它似乎有分歧,这会产生非常高的错误,numpy 会抱怨。您可以从您在 J_history
.
中维护的成本历史记录中看到这一点
我尝试搜索我遇到的溢出错误,但没有成功。
当我 运行 这个程序时,我得到 运行 对我来说毫无意义的时间错误。 这是我使用的数据:https://pastebin.com/MLWvUarm
import numpy as np
def loadData():
data = np.loadtxt('data.txt', delimiter=',')
x = np.c_[data[:,0:2]]
y = np.c_[data[:,-1]]
return x, y
def hypothesis(x, theta):
h = x.dot(theta)
return h
def computeCost(x, y, theta):
m = np.size(y, 0)
h = hypothesis(x, theta)
J = (1/(2*m)) * np.sum(np.square(h-y))
return J
def gradient_descent(x, y, theta, alpha, mxIT):
m = np.size(y, 0)
J_history = np.zeros((mxIT, 1))
for it in range(mxIT):
hyp = hypothesis(x, theta)
err = hyp - y
theta = theta - (alpha/m) * (x.T.dot(err))
J_history[it] = computeCost(x, y, theta)
return theta, J_history
def main():
x, y = loadData()
x = np.c_[np.ones(x.shape[0]), x]
theta = np.zeros((np.size(x, 1), 1))
alpha = 0.01
mxIT = 400
theta, j_his = gradient_descent(x, y, theta, alpha, mxIT)
print(theta)
if __name__ == "__main__":
main()
如何解决这个问题?
加载x
后,尝试除以均值,看是否收敛。 Link 均值文档:numpy.mean
...
x, y = loadData()
x = x / x.mean(axis=0, keepdims=True)
x = np.c_[np.ones(x.shape[0]), x]
...
目前它似乎有分歧,这会产生非常高的错误,numpy 会抱怨。您可以从您在 J_history
.