Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/usr/bin/env python3 

2 

3from FADiff import FADiff 

4 

5 

class Scal:
    """
    A scalar value that carries its partial derivatives for forward-mode
    automatic differentiation.

    Partial derivatives are stored in ``_der`` as a dict mapping each input
    variable (a ``Scal`` created with ``new_input=True`` and registered in
    ``FADiff._fadscal_inputs``) to the partial derivative of this value with
    respect to that input.
    """

    def __init__(self, val, der=None, parents=None, name=None, new_input=False):
        """
        Construct a scalar variable.

        Parameters
        ----------
        val : float
            value of the scalar variable
        der : float, dictionary
            if ``new_input`` is True, the seed partial derivative of the
            variable with respect to itself; otherwise a dict mapping input
            variables to partial derivatives
        parents : list of Scal objects, optional
            the parent/grandparent vars of the variable (a fresh empty list
            when omitted)
        name : str, optional
            the name of the variable (stored, currently unused)
        new_input : boolean
            if True, register this variable as a new input variable
        """
        self._val = val
        if new_input:
            # Every input var's gradient dict holds an entry for every other
            # input var, so cross-partials between independent inputs are 0.
            self._der = {}
            for var in FADiff._fadscal_inputs:
                self._der[var] = 0  # Other input's partial der in self
                var._der[self] = 0  # Self's partial der in the other input
            self._der[self] = der  # Seed partial der of self wrt self
            FADiff._fadscal_inputs.append(self)  # Register as global input
        else:
            self._der = der
        self._name = name  # TODO: Utilize if have time?
        # A fresh list per instance instead of one shared mutable default.
        self._parents = [] if parents is None else parents

    def __add__(self, other):
        """
        Add ``other`` to self.

        Parameters
        ----------
        other : Scal, constant
            the Scal object or constant being added to self

        Returns
        -------
        new Scal instance whose partials are the sum of the operands'
        partials (a constant contributes nothing)
        """
        # Only the attribute access decides "Scal vs constant"; an
        # AttributeError raised elsewhere must not be misread as a constant.
        try:
            other_val, other_der = other._val, other._der
        except AttributeError:
            # d(self + c) = d(self); copy so the result does not alias
            # self's gradient dict.
            return Scal(self._val + other, dict(self._der),
                        self._set_parents(self))
        # Union of both operands' input vars: a value built before a later
        # input was registered has no entry for it, i.e. its partial is 0.
        der = {var: self._der.get(var, 0) + other_der.get(var, 0)
               for var in {**self._der, **other_der}}
        return Scal(self._val + other_val, der, self._set_parents(self, other))

    def __radd__(self, other):
        """Support ``constant + Scal`` by delegating to ``__add__``."""
        return self.__add__(other)

    def __mul__(self, other):
        """
        Multiply self by ``other``.

        Parameters
        ----------
        other : Scal, constant
            the Scal object or constant multiplying self

        Returns
        -------
        new Scal instance whose partials follow the product rule
        """
        try:
            other_val, other_der = other._val, other._der
        except AttributeError:
            # d(self * c) = c * d(self)
            der = {var: part_der * other for var, part_der in self._der.items()}
            return Scal(self._val * other, der, self._set_parents(self))
        # Product rule over the union of both operands' input vars; a missing
        # entry means that partial is 0.
        der = {var: self._val * other_der.get(var, 0)
               + self._der.get(var, 0) * other_val
               for var in {**self._der, **other_der}}
        return Scal(self._val * other_val, der, self._set_parents(self, other))

    def __rmul__(self, other):
        """Support ``constant * Scal`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    @property
    def val(self):
        """Return the value wrapped in a one-element list."""
        return [self._val]

    @property
    def der(self):
        """
        Return partial derivatives wrt the root input vars this value uses.

        Returns
        -------
        list or None
            for a derived value, the partials wrt its ancestor input vars in
            ``_der`` insertion order; for a registered input var without
            parents, ``[partial wrt itself]``; otherwise None
        """
        partials = [part_der for var, part_der in self._der.items()
                    if var in self._parents]
        if partials:  # For output vars
            return partials
        if self in FADiff._fadscal_inputs:  # For input vars (no parents)
            return [self._der[self]]
        return None

    @staticmethod
    def _set_parents(var1, var2=None):
        """Return var1 (and var2) plus all their ancestors, without duplicates."""
        parents = [var1, *var1._parents]
        if var2 is not None:
            parents += [var2, *var2._parents]
        # dict.fromkeys removes duplicates while keeping a deterministic order.
        return list(dict.fromkeys(parents))