Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torch 

2from torch import Tensor 

3 

4 

class CosineLoss(torch.nn.CosineSimilarity):
    """Cosine-similarity based loss, computed as ``1 - cosine_similarity``.

    The loss is 0 for vectors pointing in the same direction, 1 for
    orthogonal vectors (or when either vector is all zeros, because the
    similarity is then 0) and 2 for vectors pointing in opposite directions.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Instantiate the loss.

        All arguments are passed unchanged to `torch.nn.CosineSimilarity`
        (e.g. ``dim`` and ``eps``).
        """
        super().__init__(*args, **kwargs)

    def forward(self, truth: Tensor, prediction: Tensor) -> Tensor:
        """Calculate the cosine loss between two tensors.

        Parameters
        ----------
        truth : Tensor
            Reference values.
        prediction : Tensor
            Predicted values, compared against ``truth``.

        Returns
        -------
        Tensor
            ``1 - cosine_similarity(truth, prediction)``; the dimension the
            similarity is computed over is reduced away.

        Examples
        --------
        >>> loss = CosineLoss(dim=1, eps=1e-4)
        >>> loss(torch.ones([1, 2, 5]), torch.zeros([1, 2, 5]))
        tensor([[1., 1., 1., 1., 1.]])
        >>> loss(torch.zeros([1, 2, 5]), torch.zeros([1, 2, 5]))
        tensor([[1., 1., 1., 1., 1.]])
        >>> out = loss(torch.ones([1, 2, 5]), 5 * torch.ones([1, 2, 5]))
        >>> torch.allclose(out, torch.zeros([1, 5]), atol=1e-5)
        True
        """
        similarity = super().forward(truth, prediction)
        # Flip the similarity so that a perfect match gives the minimum, 0.
        return 1 - similarity

40 

41 

class PearsonCorrelation(torch.nn.Module):
    """Pearson correlation coefficient between two tensors along one axis."""

    def __init__(self, axis: int = 1, eps: float = 1e-4) -> None:
        """Instantiate the module.

        Creates a callable object that calculates the Pearson correlation
        along an axis.

        Parameters
        ----------
        axis : int, optional
            The axis over which the correlation is calculated.
            For instance, if the input has shape [5, 500] and the axis is set
            to 1, the output will be of shape [5]. On the other hand, if the
            axis is set to 0, the output will have shape [500], by default 1
        eps : float, optional
            Number added to the denominator to prevent division by 0,
            by default 1e-4
        """
        super().__init__()
        self.axis = axis
        self.eps = eps

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Calculate the Pearson correlation between ``x`` and ``y``.

        Parameters
        ----------
        x : Tensor
            First input.
        y : Tensor
            Second input; combined element-wise with ``x``.

        Returns
        -------
        Tensor
            Correlation along ``self.axis``; that axis is reduced away.
            Because ``eps`` is added to the denominator, the magnitude stays
            slightly below 1 even for perfectly correlated inputs, and an
            input that is constant along the axis yields 0.

        Examples
        --------
        >>> corr = PearsonCorrelation(axis=1, eps=1e-4)
        >>> x = torch.tensor([[1.0, 2.0, 3.0]])
        >>> torch.allclose(corr(x, 2 * x), torch.ones(1), atol=1e-3)
        True
        >>> torch.allclose(corr(x, -x), -torch.ones(1), atol=1e-3)
        True
        >>> corr(torch.ones([1, 2, 5]), torch.zeros([1, 2, 5]))
        tensor([[0., 0., 0., 0., 0.]])
        >>> out = corr(torch.rand([5, 174]), torch.rand([5, 174]))
        >>> out.shape
        torch.Size([5])
        >>> corr = PearsonCorrelation(axis=0, eps=1e-4)
        >>> out = corr(torch.rand([5, 174]), torch.rand([5, 174]))
        >>> out.shape
        torch.Size([174])
        """
        # Center both inputs; keepdim=True keeps the reduced axis so the mean
        # broadcasts against the original tensor.
        vx = x - torch.mean(x, dim=self.axis, keepdim=True)
        vy = y - torch.mean(y, dim=self.axis, keepdim=True)

        covariance = torch.sum(vx * vy, dim=self.axis)
        # NOTE(review): torch.sqrt has an unbounded derivative at 0, so an
        # input that is constant along `axis` may give non-finite gradients
        # when this is used for training -- confirm with the callers.
        norm_x = torch.sqrt(torch.sum(vx ** 2, dim=self.axis))
        norm_y = torch.sqrt(torch.sum(vy ** 2, dim=self.axis))
        # eps keeps the division finite when either input has zero variance.
        return covariance / (norm_x * norm_y + self.eps)