Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python3 

2 

3from FADiff import FADiff 

4 

5 

class Scal:
    """Scalar value that carries partial derivatives with respect to inputs.

    Each instance stores its value in ``_val`` and a dict ``_der`` mapping
    input variables to the partial derivative of this value with respect to
    that input. ``+`` and ``*`` (with another Scal-like object or a plain
    number, on either side) propagate those partials with the sum and product
    rules. They also record the operands (``_parents``) and the registered
    inputs the result depends on (``_root_inputs``).
    """

    # Scratch slot used by the former back-tracing ``der`` implementation.
    # Nothing in this class reads or writes it any more; it is kept only so
    # that outside references to ``Scal._tmp_part_der`` keep working.
    _tmp_part_der = 0

    def __init__(self, val, der=None, parents=None,
                 roots=None, name=None, new_input=False):
        """Create a scalar node.

        Args:
            val: Numeric value of this node.
            der: With ``new_input=True``, the partial derivative of this new
                input with respect to itself (stored under key ``self``).
                Otherwise, the full {input: partial derivative} mapping;
                ``None`` is treated as an empty mapping.
            parents: Operand nodes this value was computed from
                (default: a fresh empty list).
            roots: Registered inputs this value depends on
                (default: a fresh empty list).
            name: Optional label; stored as ``_name``, not used here.
            new_input: If True, append this node to
                ``FADiff._revscal_inputs``. This node gets a zero partial
                with respect to every previously registered input, and each
                of those inputs gets a zero partial with respect to this node.
        """
        self._val = val
        self._grad = 0  # TODO: Not sure if need
        if new_input:
            self._der = {}
            for var in FADiff._revscal_inputs:
                self._der[var] = 0
                var._der[self] = 0
            self._der[self] = der
            FADiff._revscal_inputs.append(self)
        else:
            # None -> {} so the arithmetic below can always iterate _der.
            self._der = {} if der is None else der
        self._name = name
        # Fresh lists per instance: list defaults in the signature would be
        # one object shared by every node created without explicit arguments.
        self._parents = [] if parents is None else parents
        self._root_inputs = [] if roots is None else roots

    def _merged_keys(self, other_der):
        """Return ``self._der``'s keys, then keys found only in ``other_der``."""
        return list(dict.fromkeys([*self._der, *other_der]))

    def __add__(self, other):
        """Return ``self + other`` as a new Scal.

        ``other`` may be Scal-like (has ``_val`` and ``_der``) or a plain
        number. A variable missing from one operand's ``_der`` counts as a
        zero partial there.
        """
        try:
            other_val, other_der = other._val, other._der
        except AttributeError:
            # Plain number: shifts the value; partials are unchanged. Copy
            # the dict so the result does not alias this node's ``_der``.
            return Scal(self._val + other, dict(self._der), [self],
                        self._set_roots(self))
        # Sum rule: d(u + v) = du + dv.
        der = {
            var: self._der.get(var, 0) + other_der.get(var, 0)
            for var in self._merged_keys(other_der)
        }
        return Scal(self._val + other_val, der, [self, other],
                    self._set_roots(self, other))

    def __radd__(self, other):
        """Support ``number + Scal`` by delegating to ``__add__``."""
        return self.__add__(other)

    def __mul__(self, other):
        """Return ``self * other`` as a new Scal.

        ``other`` may be Scal-like (has ``_val`` and ``_der``) or a plain
        number. A variable missing from one operand's ``_der`` counts as a
        zero partial there.
        """
        try:
            other_val, other_der = other._val, other._der
        except AttributeError:
            # Plain number: scales the value and every partial. Only ``self``
            # is a graph operand; the number must not reach _set_roots, which
            # would try to read ``_parents`` from it.
            der = {var: part_der * other
                   for var, part_der in self._der.items()}
            return Scal(self._val * other, der, [self],
                        self._set_roots(self))
        # Product rule: d(u * v) = u dv + du v.
        der = {
            var: self._val * other_der.get(var, 0)
            + self._der.get(var, 0) * other_val
            for var in self._merged_keys(other_der)
        }
        return Scal(self._val * other_val, der, [self, other],
                    self._set_roots(self, other))

    def __rmul__(self, other):
        """Support ``number * Scal`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    @property
    def val(self):
        """The value, wrapped in a single-element list."""
        return [self._val]

    @property
    def der(self):
        """Partials with respect to the inputs this value depends on.

        Entries follow the key order of ``_der``, restricted to keys that
        appear in ``_root_inputs``.
        """
        # ``_der`` already holds total derivatives: ``+`` and ``*`` apply the
        # chain rule at every step. So the entries are read directly. The old
        # back-trace multiplied these totals together again along the parent
        # chain, which overcounted (e.g. 48 instead of 12 for (x*x)*x at x=2)
        # and depended on operand order.
        return [part_der for var, part_der in self._der.items()
                if var in self._root_inputs]

    @staticmethod
    def _set_roots(var1, var2=None):
        """Collect the registered inputs that ``var1`` (and ``var2``) depend on.

        A node with no parents that is listed in ``FADiff._revscal_inputs``
        counts as a root itself. Any other node contributes its own
        ``_root_inputs``. Duplicates are dropped, keeping the first
        occurrence, so expressions such as ``x * x`` do not make the list
        double at each step.
        """
        collected = []
        for var in (var1, var2):
            if var is None:
                continue
            if not var._parents and var in FADiff._revscal_inputs:
                collected.append(var)
            else:
                collected.extend(var._root_inputs)
        return list(dict.fromkeys(collected))