Bitte lesen Sie die Dokumentation zu sorgfältig durch https://pytorch.org/docs/stable/tensors.html#torch.Tensor.backward um es besser zu verstehen.
Standardmäßig erwartet pytorch, dass backward()der letzte Ausgang des Netzwerks aufgerufen wird - die Verlustfunktion. Der Verlust Funktion gibt stets eine skalare und daher die Gradienten der skalaren Verlust WRT alle anderen Variablen / Parameter gut definiert ist (der Kettenregel).
Daher wird standardmäßig backward()ein Skalartensor aufgerufen und erwartet keine Argumente.
Zum Beispiel:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
ergibt
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
Wie erwartet : d(a^2)/da = 2a.
Wenn Sie jedoch backwardden 2-mal-3- outTensor aufrufen (keine Skalarfunktion mehr), was erwarten Sie a.graddann? Sie benötigen tatsächlich eine 2-mal-3-mal-2-mal-3-Ausgabe: d out[i,j] / d a[k,l](!)
Pytorch unterstützt diese nicht skalaren Funktionsableitungen nicht. Stattdessen nimmt Pytorch an, dass outes sich nur um einen Zwischentensor handelt, und irgendwo "stromaufwärts" gibt es eine Skalarverlustfunktion , die durch die Kettenregel bereitgestellt wird d loss/ d out[i,j]. Dieser "Upstream" -Gradient hat die Größe 2 x 3, und dies ist tatsächlich das Argument, das Sie backwardin diesem Fall angeben: out.backward(g)wo g_ij = d loss/ d out_ij.
Die Gradienten werden dann durch Kettenregel berechnet d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
Da Sie aals "Upstream" -Gradienten angegeben haben, haben Sie
a.grad[i,j] = 2 * a[i,j] * a[i,j]
Wenn Sie die "Upstream" -Gradienten als alle bereitstellen würden
out.backward(torch.ones(2,3))
print(a.grad)
ergibt
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
Wie erwartet.
Es ist alles in der Kettenregel.