Coverage for /usr/local/lib/python3.10/dist-packages/Adifpy/differentiate/forward_mode.py: 100%

15 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-07 00:47 -0500

1import numpy as np 

2 

3from Adifpy.differentiate.dual_number import DualNumber 

4from Adifpy.differentiate.helpers import isscalar 

5 

6 

def forward_mode(func, pt, seed_vector=None):
    """Compute the value and directional derivative at a point, for a given function.

    Forward mode AD calculates the derivative of a function at a point
    with respect to some seed vector, at machine precision.

    Args:
        func (Callable): the function on which to evaluate
        pt (float | list | tuple | ndarray): the point at which to evaluate
        seed_vector (float | list | tuple | ndarray, optional): the direction in
            which to evaluate. Defaults to 1 when the function's input space
            is R; it must be given explicitly, with one entry per component of
            ``pt``, when ``pt`` is a vector.

    Returns:
        If the function's output space is R, a tuple of the value and directional derivative.
        Otherwise, a tuple of two lists: the values and directional derivatives for each component.
        Output components that are plain constants (not dual numbers) get a derivative of 0.

    Raises:
        TypeError: if ``seed_vector`` is omitted while ``pt`` is a vector.
        ValueError: if ``pt`` and ``seed_vector`` have different lengths.

    >>> f = lambda x: x**2 + 3*x
    >>> forward_mode(f, 1, seed_vector=1)
    (4, 5)
    >>> forward_mode(f, 1)
    (4, 5)
    >>> f = lambda x: [np.sin(x[0]), np.cos(x[1])]

    NOTE: We expect a -0.0 here, since the derivative of cos is -sin
    >>> forward_mode(f, [0, 0], [1, 0])
    ([0.0, 1.0], [1.0, -0.0])
    """
    vector_types = (np.ndarray, list, tuple)

    if seed_vector is None:
        # The documented default of 1 only makes sense for a scalar input;
        # a vector input needs an explicit direction.
        if isinstance(pt, vector_types):
            raise TypeError("seed_vector is required when pt is a vector")
        seed_vector = 1

    if isinstance(seed_vector, vector_types):
        # zip() would silently truncate mismatched inputs and produce a
        # derivative along the wrong direction, so reject the mismatch.
        n_pt, n_seed = len(pt), len(seed_vector)
        if n_pt != n_seed:
            raise ValueError(
                f"pt has {n_pt} components but seed_vector has {n_seed}"
            )
        # Create an array of dual numbers to pass through the function
        array_pt = np.array([DualNumber(p, seed) for p, seed in zip(pt, seed_vector)])
        passed = func(array_pt)
    else:
        passed = func(DualNumber(pt, seed_vector))

    if isscalar(passed):
        # The function did not depend on its input (constant output)
        return (passed, 0)
    if isinstance(passed, DualNumber):
        return (passed.real, passed.dual)

    # Vector output: walk it exactly once so one-shot iterables (e.g. a
    # generator returned by func) yield both values and derivatives.
    # Some outputs may be scalars (for functions that map to constants).
    reals, duals = [], []
    for component in passed:
        if isinstance(component, DualNumber):
            reals.append(component.real)
            duals.append(component.dual)
        else:
            reals.append(component)
            duals.append(0)

    return reals, duals