Pytorch 计算两个张量的欧式距离

发布时间:2022-06-29 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了Pytorch 计算两个张量的欧式距离脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

1.Pytorch计算公式

a,b为两个张量,且a.size=(B,N,3),b.size()=(B,M,3),计算a中各点到b中各点的距离,返回距离张量c,c.size()=(B,N,M)。不考虑Batch时,可以将理解:c的第i行j列的值表示a中第i个点到b中第j个点的距离。

import torch

def EuclideanDistance(t1,t2):
    dim=len(t1.size())
    if dim==2:
        N,C=t1.size()
        M,_=t2.size()
        dist = -2 * torch.matmul(t1, t2.permute(1, 0))
        dist += torch.sum(t1 ** 2, -1).view(N, 1)
        dist += torch.sum(t2 ** 2, -1).view(1, M)
        dist=torch.sqrt(dist)
        return dist
    elif dim==3:
        B,N,_=t1.size()
        _,M,_=t2.size()
        dist = -2 * torch.matmul(t1, t2.permute(0, 2, 1))
        dist += torch.sum(t1 ** 2, -1).view(B, N, 1)
        dist += torch.sum(t2 ** 2, -1).view(B, 1, M)
        dist=torch.sqrt(dist)
        return dist
    else:
        print('error...')

print(f'dimensional 2.......')
a=torch.Tensor([[0,0],[1,1]])
b=torch.Tensor([[1,0],[3,4]])
print(f'size of a:{a.size()}tsize of b:{b.size()}')
print(f'distance of point a and b is: {EuclideanDistance(a,b)}')
print(f'ndimensional 3.......')
a=torch.unsqueeze(a,dim=0)
b=torch.unsqueeze(b,dim=0)
print(f'size of a:{a.size()}tsize of b:{b.size()}')
print(f'distance of point a and b is: {EuclideanDistance(a,b)}')

2.代码理解

2.1定义待计算张量

现有张量a,b如下:

Pytorch 计算两个张量的欧式距离

2.2距离公式

有距离公式如下:

Pytorch 计算两个张量的欧式距离

2.3分步计算

(1)计算:

d1=-2 * torch.matmul(a, b.permute(0, 2, 1))

(1)结果如下:

Pytorch 计算两个张量的欧式距离

(2)计算:

d2=torch.sum(a** 2, -1)
d3=torch.sum(b** 2, -1)

(2)结果如下:

Pytorch 计算两个张量的欧式距离

当前有:d1.size=(B,N,M),d2.size()=(B,N,1),d3.size()=(B,M,1)

可以看到d1中的i行中保持不变的部分为a中的第i个点,d1中第j列中不变的部分对应b中的j行。因此,只需在d1的行上加上一个d2的对应行,列上加d3的对应行即可。

(3)相加:

d=d1+d2.view(B,N,1)+d3.view(B,1,M)

(3)结果如下:

Pytorch 计算两个张量的欧式距离

 (4)开根

d=torch.sqrt(d)

 

脚本宝典总结

以上是脚本宝典为你收集整理的Pytorch 计算两个张量的欧式距离全部内容,希望文章能够帮你解决Pytorch 计算两个张量的欧式距离所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。
标签: