Hvordan bruger man metoden 'torch.argmax()' i PyTorch?

Hvordan Bruger Man Metoden Torch Argmax I Pytorch



I PyTorch er ' torch.argmax() ”-metoden er en indbygget funktion, der returnerer indekser for maksimalværdier af en bestemt tensor på tværs af en given dimension. Brugere bruger denne funktion, når de arbejder med tensorer og ønsker at finde indekset for den maksimale værdi langs en tensors givne dimension. Desuden kan denne metode også være nyttig til klassificering, hvor brugere ønsker at vide, hvilken klasse der har størst sandsynlighed.

Denne blog vil eksemplificere metoden til at bruge 'torch.argmax()' metoden i PyTorch.

Hvordan bruger man metoden 'torch.argmax()' i PyTorch?

Metoden 'torch.argmax()' tager enhver 1D- eller 2D-tensor som input og returnerer en tensor, der indeholder indekserne/indekserne for de maksimale værdier langs den givne dimension.







Syntaksen for 'torch.argmax()'-metoden er angivet nedenfor:



fakkel. argmax ( < input_tensor > )

For at bruge denne metode i PyTorch skal du gennemgå følgende eksempler for en bedre forståelse:



Eksempel 1: Brug metoden 'torch.argmax()' med 1D Tensor

I det første eksempel vil vi oprette en 1D-tensor og bruge 'torch.argmax()'-metoden med den. Lad os følge nedenstående trin-for-trin procedure:





Trin 1: Importer PyTorch Library

Først skal du importere ' fakkel ”-biblioteket for at bruge metoden “torch.argmax()”:

importere fakkel

Trin 2: Opret 1D Tensor

Opret derefter en 1D-tensor og udskriv dens elementer. Her laver vi følgende ' Tiere 1 ' tensor fra en liste ved hjælp af ' torch.tensor() ' funktion:



Tiere 1 = fakkel. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

Print ( Tiere 1 )

Dette har skabt en 1D-tensor som vist nedenfor:

Trin 3: Find indekser for maksimal værdi

Brug nu ' torch.argmax() ”-funktion til at finde indekset/indeksene for den maksimale værdi i Tiere 1 ' tensor:

T1_ind = fakkel. argmax ( Tiere 1 )

Trin 4: Udskriv indeks over maksimal værdi

Til sidst skal du vise maksimumværdiens indeks i inputtensoren:

Print ( 'Indekser:' , T1_ind )

Nedenstående output viser indekset for den maksimale værdi i ' Tiere 1 ' tensor dvs. 4. Det betyder, at den højeste værdi af tensoren er ved det 4. indeks, som er ' 9 ”:

Eksempel 2: Brug metoden 'torch.argmax()' med 2D Tensor

I det andet eksempel vil vi oprette en 2D-tensor og bruge 'torch.argmax()'-metoden med den. Lad os følge de angivne trin:

Trin 1: Importer PyTorch Library

Først skal du importere ' fakkel ”-biblioteket for at bruge metoden “torch.argmax()”:

importere fakkel

Trin 2: Opret 2D Tensor

Brug derefter ' torch.tensor() ”-funktion til at skabe en 2D-tensor og udskrive dens elementer. Her laver vi følgende ' Tiere 2 '2D tensor:

Tiere 2 = fakkel. tensor ( [ [ 4 , 1 , - 7 ] , [ femten , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

Print ( Tiere 2 )

Dette har skabt en 2D-tensor som vist nedenfor:

Trin 3: Find indekser for maksimal værdi

Find nu indekset for den maksimale værdi i ' Tiere 2 ' tensor ved at bruge ' torch.argmax() ' funktion:

T2_ind = fakkel. argmax ( Tiere 2 )

Trin 4: Udskriv indeks over maksimal værdi

Vis endelig maksimumværdiens indeks i inputtensoren:

Print ( 'Indekser:' , T2_ind )

Ifølge nedenstående output er indekset for den maksimale værdi i ' Tiere 2 ' tensor er '3'. Det betyder, at den højeste værdi af tensoren er ved det 3. indeks, som er ' femten ”:

Trin 5: Find indekser for maksimal værdi langs kolonner

Desuden kan brugere også finde indekserne/indekserne for de maksimale værdier langs hver kolonne i en tensor. For eksempel kan vi bruge ' dim=0 ” argument med funktionen “torch.argmax()”. Den finder de maksimale værdiers indekser langs kolonner i ' Tiere 2 ” tensor og udskriver derefter disse indekser:

col_index = fakkel. argmax ( Tiere 2 , svag = 0 )

Print ( 'Indekser i kolonner:' , col_index )

Nedenstående output viser indekserne for de maksimale værdier langs hver kolonne af tensoren:

Trin 6: Find indekser for maksimal værdi langs rækker

På samme måde kan brugere også finde indekserne/indekserne for de maksimale værdier langs hver række af en tensor. Brug for eksempel ' dim=1 ” argument med funktionen “torch.argmax()” for at finde de maksimale værdiers indekser langs rækker i “Tens2”-tensoren og derefter udskrive disse indekser:

række_indeks = fakkel. argmax ( Tiere 2 , svag = 1 )

Print ( 'Indekser i rækker:' , række_indeks )

Den maksimale værdis indekser langs hver række af en 'Tens2' tensor kan ses nedenfor:

Vi har effektivt forklaret metoden til at bruge 'torch.argmax()' metoden i PyTorch.

Bemærk : Du kan få adgang til vores Google Colab Notebook her link .

Konklusion

For at bruge 'torch.argmax()'-metoden i PyTorch skal du først importere ' fakkel ” bibliotek. Opret derefter den ønskede 1D- eller 2D-tensor og se dens elementer. Brug derefter ' torch.argmax() ” metode til at finde/beregne indekserne/indekserne for de maksimale værdier i tensoren. Desuden kan brugere også finde den maksimale værdis indekser langs hver række eller kolonne i tensoren ved hjælp af ' svag ' argument. Vis endelig maksimumværdiens indeks i inputtensoren. Denne blog har eksemplificeret metoden til at bruge 'torch.argmax()' metoden i PyTorch.