如何手动更改模型的参数?
How to change my model's parameters manually?
我有一个模型:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(1, 3)
self.fc2 = nn.Linear(3, 2)
self.fc3 = nn.Linear(2, 1)
def forward(self, x):
x1 = self.fc1(x)
x = torch.relu(x1)
x2 = self.fc2(x)
x = torch.relu(x2)
x3 = self.fc3(x)
return x3, x2, x1
net = Model()
我正在尝试使用
手动更新参数
i, j = torch.meshgrid(torch.arange(3), torch.arange(2))
i = i.reshape(-1)
j = j.reshape(-1)
update = torch.ones(6,1)
print(i)
print(j)
print(update.squeeze())
print(net.fc2.weight[j,i].data)
net.fc2.weight[j,i].data += update.squeeze()
print(net.fc2.weight[j,i].data)
>>> tensor([0, 0, 1, 1, 2, 2])
tensor([0, 1, 0, 1, 0, 1])
tensor([1., 1., 1., 1., 1., 1.])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
但似乎什么都没有改变。
但是,如果我这样做
print(net.fc2.weight[1].data)
net.fc2.weight[1].data += 1
print(net.fc2.weight[1].data)
>>> tensor([-0.3770, -0.2123, -0.5580])
tensor([0.6230, 0.7877, 0.4420])
他们确实改变了。
我在第一种方法中做错了什么,我怎样才能让它起作用?
你遗漏的要点很简单:当你进行“常量索引”时,你会得到张量的“视图”,否则(即用另一个张量索引)你会得到一个新的张量或一个新节点计算图。
PyTorch
提供了一个 .data_ptr()
方法来查看底层内存指针。
>> net.fc2.weight.data.data_ptr()
2911054070464
>> net.fc2.weight[1].data.data_ptr()
2911054070464
常量索引未更改基础原始数据。但是,使用张量进行索引会创建一个新节点,因此会创建一个新的底层原始内存位置
>> net.fc2.weight[j, i].data.data_ptr()
2911054068672
因此,在您的情况下,您正在使用 net.fc2.weight[j,i]
创建一个新的 tensor/node 并为其分配新值。这就是为什么您的原始张量保持不变的原因。在常量索引情况下,您正在更改相同的内存位置,因此会反映更改。
解决您的问题,而不是这样做
net.fc2.weight[j,i].data += update.squeeze()
这样做
net.fc2.weight.data[j,i] += update.squeeze()
.. 本质上是先抓取底层 .data
然后对其进行索引,这意味着索引操作完全脱离了 autograd 的跟踪机制。
我有一个模型:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(1, 3)
self.fc2 = nn.Linear(3, 2)
self.fc3 = nn.Linear(2, 1)
def forward(self, x):
x1 = self.fc1(x)
x = torch.relu(x1)
x2 = self.fc2(x)
x = torch.relu(x2)
x3 = self.fc3(x)
return x3, x2, x1
net = Model()
我正在尝试使用
手动更新参数i, j = torch.meshgrid(torch.arange(3), torch.arange(2))
i = i.reshape(-1)
j = j.reshape(-1)
update = torch.ones(6,1)
print(i)
print(j)
print(update.squeeze())
print(net.fc2.weight[j,i].data)
net.fc2.weight[j,i].data += update.squeeze()
print(net.fc2.weight[j,i].data)
>>> tensor([0, 0, 1, 1, 2, 2])
tensor([0, 1, 0, 1, 0, 1])
tensor([1., 1., 1., 1., 1., 1.])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
tensor([-0.0209, -0.3770, 0.4982, -0.2123, -0.2630, -0.5580])
但似乎什么都没有改变。
但是,如果我这样做
print(net.fc2.weight[1].data)
net.fc2.weight[1].data += 1
print(net.fc2.weight[1].data)
>>> tensor([-0.3770, -0.2123, -0.5580])
tensor([0.6230, 0.7877, 0.4420])
他们确实改变了。
我在第一种方法中做错了什么,我怎样才能让它起作用?
你遗漏的要点很简单:当你进行“常量索引”时,你会得到张量的“视图”,否则(即用另一个张量索引)你会得到一个新的张量或一个新节点计算图。
PyTorch
提供了一个 .data_ptr()
方法来查看底层内存指针。
>> net.fc2.weight.data.data_ptr()
2911054070464
>> net.fc2.weight[1].data.data_ptr()
2911054070464
常量索引未更改基础原始数据。但是,使用张量进行索引会创建一个新节点,因此会创建一个新的底层原始内存位置
>> net.fc2.weight[j, i].data.data_ptr()
2911054068672
因此,在您的情况下,您正在使用 net.fc2.weight[j,i]
创建一个新的 tensor/node 并为其分配新值。这就是为什么您的原始张量保持不变的原因。在常量索引情况下,您正在更改相同的内存位置,因此会反映更改。
解决您的问题,而不是这样做
net.fc2.weight[j,i].data += update.squeeze()
这样做
net.fc2.weight.data[j,i] += update.squeeze()
.. 本质上是先抓取底层 .data
然后对其进行索引,这意味着索引操作完全脱离了 autograd 的跟踪机制。