在PyTorch中有四種類型的乘法運(yùn)算(位置乘法、點(diǎn)積、矩陣與向量乘法、矩陣乘法),非常容易搞混,我們一起來看看這四種乘法運(yùn)算的區(qū)別。
位置乘法
先構(gòu)建兩個(gè)張量a,b他們都是4行5列。
a = torch.arange(20).reshape([4,5])
b = torch.randn([4,5])
位置乘法,顧名思義就是將兩個(gè)張量對應(yīng)位置的元素進(jìn)行乘法運(yùn)算,運(yùn)算符是*。
可以是兩個(gè)張量相乘,也可以是標(biāo)量和張量相乘。
標(biāo)量與張量相乘,是用標(biāo)量與張量的每個(gè)元素相乘,結(jié)果張量的形狀不變。
4 * a
兩個(gè)張量相乘,是對應(yīng)位置的元素相乘,結(jié)果張量的形狀不變。
a * b
點(diǎn)積
點(diǎn)積是兩個(gè)向量(也就是一維張量)對應(yīng)位置的元素相乘后求和,結(jié)果是一個(gè)標(biāo)量,使用dot函數(shù)進(jìn)行計(jì)算。
先構(gòu)建兩個(gè)向量a、b,點(diǎn)積操作要求兩個(gè)向量的數(shù)據(jù)類型要一致,因此a中指定數(shù)據(jù)類型為float。
a = torch.arange(6, dtype=torch.float32)
b = torch.ones(6)
執(zhí)行點(diǎn)積操作,結(jié)果是一個(gè)標(biāo)量。
torch.dot(a,b)
矩陣與向量乘法
矩陣(二維張量)與向量(一維張量)的乘法是將矩陣的每一行與向量進(jìn)行點(diǎn)積,要求矩陣的列維數(shù)與向量的維數(shù)相同,結(jié)果的維數(shù)與行數(shù)相同。
使用mv函數(shù)進(jìn)行運(yùn)算。
構(gòu)建一個(gè)4行5列的矩陣和一個(gè)維數(shù)為5的向量。
a = torch.arange(20,dtype=torch.float32).reshape([4,5])
b = torch.ones(5)
使用mv函數(shù)相乘后,結(jié)果是維數(shù)為4的向量。
torch.mv(a,b)
矩陣乘法
矩陣(二維張量)乘法是用第一個(gè)矩陣的行向量與第二個(gè)矩陣的列向量進(jìn)行點(diǎn)積,要求第一個(gè)矩陣的列數(shù)與第二個(gè)矩陣的行數(shù)相同。
使用mm函數(shù)進(jìn)行運(yùn)算。
構(gòu)建兩個(gè)矩陣,一個(gè)4行5列,一個(gè)5行6列
a = torch.arange(20,dtype=torch.float32).reshape([4,5])
b = torch.randn([5,6])
使用mm函數(shù)相乘后,結(jié)果是4行6列的矩陣。
torch.mm(a,b)