I denne blog vil vi diskutere, hvordan man bruger ' torch.no_grad ” metode i PyTorch.
Hvad er 'torch.no_grad'-metoden i PyTorch?
Det ' torch.no_grad ” metode bruges til styring af kontekst inden for PyTorchs udviklingsramme. Dens formål er at stoppe beregningen af gradienter for sammenhængen mellem de efterfølgende lag af deep learning-modellen. Nytten af denne metode er, når gradienter ikke er påkrævet i en bestemt model, så kan de deaktiveres for at allokere flere hardwareressourcer til behandlingen af modellens træningsløkke.
Hvordan bruger man metoden 'torch.no_grad' i PyTorch?
Gradienter beregnes inden for tilbageløbet i PyTorch. Som standard har PyTorch automatisk differentiering aktiveret for alle maskinlæringsmodeller. Deaktivering af gradientberegning er afgørende for udviklere, der ikke har tilstrækkelige hardwarebehandlingsressourcer.
Følg nedenstående trin for at lære, hvordan du bruger ' torch.no_grad ” metode til at deaktivere beregningen af gradienter i PyTorch:
Trin 1: Start Colab IDE
Google Colaboratory er et glimrende valg af platform til udvikling af projekter, der bruger PyTorch-rammen på grund af dets dedikerede GPU'er. Gå til Colab internet side og åbne en ' Ny notesbog ' som vist:
Trin 2: Installer og importer Torch Library
Al funktionaliteten i PyTorch er indkapslet af ' fakkel ” bibliotek. Dets installation og import er afgørende, før arbejdet påbegyndes. Det ' !pip ” installationspakken af Python bruges til at installere biblioteker, og den importeres til projektet ved hjælp af ” importere kommando:
!pip installer lommelygteimport lommelygte
Trin 3: Definer en PyTorch-tensor med en gradient
Tilføj en PyTorch-tensor til projektet ved hjælp af ' torch.tensor() ” metode. Giv den derefter en gyldig gradient ved hjælp af ' requires_grad=Sandt ” metode som vist i koden nedenfor:
A = torch.tensor([5.0], requires_grad=True)
Trin 4: Brug 'torch.no_grad'-metoden til at fjerne gradienten
Fjern derefter gradienten fra den tidligere definerede tensor ved hjælp af ' torch.no_grad ” metode:
med torch.no_grad():B = A**2 + 16
Ovenstående kode fungerer som følger:
- Det ' no_grad() '-metoden bruges i en ' med ” sløjfe.
- Hver tensor indeholdt i løkken har sin gradient fjernet.
- Til sidst skal du definere en aritmetisk eksempelberegning ved hjælp af den tidligere definerede tensor og tildele den til ' B variabel som vist ovenfor:
Trin 5: Bekræft gradientfjernelsen
Det sidste trin er at bekræfte, hvad der lige blev gjort. Gradienten fra tensor ' EN ' blev fjernet, og det skal kontrolleres i outputtet ved hjælp af ' Print() ” metode:
print('Gradientberegning med torch.no_grad: ', A.grad)print('\nOriginal Tensor: ', A)
print('\nEksempel på aritmetisk beregning: ', B)
Ovenstående kode fungerer som følger:
- Det ' grad 'metoden giver os gradienten af tensor' EN ”. Det viser ingen i outputtet nedenfor, fordi gradienten er blevet fjernet ved hjælp af ' torch.no_grad ” metode.
- Den originale tensor viser stadig, at den har sin gradient set fra ' requires_grad=Sandt ” udsagn i outputtet.
- Til sidst viser den aritmetiske prøveberegning resultatet af den tidligere definerede ligning:
Bemærk : Du kan få adgang til vores Colab Notebook her link .
Pro-Tip
Det ' torch.no_grad ”-metoden er ideel, hvor gradienterne ikke er nødvendige, eller når der er behov for at reducere behandlingsbelastningen på hardwaren. En anden brug af denne metode er under inferens, fordi modellen kun bruges til at lave forudsigelser baseret på nye data. Da der ikke er nogen træning involveret, giver det fuldstændig mening blot at deaktivere beregningen af gradienter.
Succes! Vi har vist dig, hvordan du bruger 'torch.no_grad'-metoden til at deaktivere gradienter i PyTorch.
Konklusion
Brug ' torch.no_grad ” metode i PyTorch ved at definere den inde i en ” med ”-løkke og alle tensorer indeholdt i vil få deres gradient fjernet. Dette vil medføre forbedringer i behandlingshastigheder og forhindre akkumulering af gradienter i træningsløkken. I denne blog har vi vist, hvordan dette ' torch.no_grad ”-metoden kan bruges til at deaktivere gradienterne for udvalgte tensorer i PyTorch.