Coverage for Adifpy/differentiate/evaluator.py: 72%
29 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-22 19:44 -0500
1"""Automatic Differentiation object"""
3from typing import Callable
5import numpy as np
7from Adifpy.differentiate.forward_mode import forward_mode
8from Adifpy.differentiate.reverse_mode import reverse_mode
11class Evaluator:
12 """AD evaluation object
14 >>> my_evaluator = Evaluator(lambda x: x*x)
15 >>> my_evaluator.eval(1)
16 (1, 2)
17 >>> my_evaluator.eval(3)
18 (9, 9)
19 """
21 def __init__(self, fn: Callable):
22 self.fn = fn
24 def eval(self, pt, **kwargs):
25 """Perform AD on this Evaluator's function, at this point
27 Args:
28 pt (float | iterable): the point or vector at which to evaluate the function
29 seed_vector (iterable, optional): the seed vector, if the function has vector input
30 force_mode (str, optional): either 'forward' or 'reverse' for forcing AD mode
32 Returns:
33 ADEvaluated: the evaluated AD object
34 """
35 shape = np.shape(pt)
36 self.input_dim = 1 if shape == () else shape[0]
38 # Ensure that a seed vector is provided for vector functions
39 if self.input_dim != 1 and 'seed_vector' not in kwargs:
40 raise AttributeError('For vector functions, `seed_vector` argument is required')
41 elif 'seed_vector' not in kwargs:
42 kwargs['seed_vector'] = [1]
44 # Set the output dimension (and ensure the function is valid)
45 try:
46 fn_output = self.fn(pt)
48 # TODO: Check for invalid functions (null returns, etc)
50 self.output_dim = 1 if type(fn_output) in [int, float] else len(fn_output)
51 except Exception as error:
52 raise RuntimeError('Evaluator function failed') from error
54 # Decide which AD mode to use, either depending on forced user input or optimized for performance
55 if 'force_mode' in kwargs:
56 match kwargs['force_mode']:
57 case 'forward':
58 differentiator = forward_mode
59 case 'reverse':
60 differentiator = reverse_mode
61 case _:
62 raise ValueError('`force_mode` argument must be either `forward` or `reverse`')
63 else:
64 differentiator = forward_mode if self.input_dim < self.output_dim else reverse_mode
66 return differentiator(func=self.fn, pt=pt, seed_vector=kwargs['seed_vector'])