# Source code for impedance.models.circuits.fitting

import warnings

import numpy as np
from scipy.linalg import inv
from scipy.optimize import curve_fit, basinhopping

from .elements import circuit_elements, get_element_from_name

# Characters treated as digits when splitting element names from their
# numeric suffixes in extract_circuit_elements.
ints = ''.join(str(digit) for digit in range(10))

def rmse(a, b):
    """
    A function which calculates the root mean squared error
    between two vectors.

    Parameters
    ----------
    a, b : array_like
        Vectors of equal length to compare. Plain sequences are accepted
        and converted to numpy arrays.

    Returns
    -------
    float
        Root mean squared error between ``a`` and ``b``.

    Notes
    ---------
    .. math::

        RMSE = \\sqrt{\\frac{1}{n}\\sum_{i=1}^{n}|a_i - b_i|^2}
    """
    # accept lists/tuples as well as arrays so `a - b` is elementwise
    a = np.asarray(a)
    b = np.asarray(b)

    n = len(a)
    # ||a - b||_2 / sqrt(n) == sqrt(sum(|a_i - b_i|^2) / n)
    return np.linalg.norm(a - b) / np.sqrt(n)

def set_default_bounds(circuit, constants=None):
    """ This function sets default bounds for optimization.

    set_default_bounds sets bounds of 0 and np.inf for all parameters,
    except the CPE and La alphas which have an upper bound of 1.

    Parameters
    -----------------
    circuit : string
        String defining the equivalent circuit to be fit

    constants : dictionary, optional
        Parameters and their values to hold constant during fitting
        (e.g. {"R0": 0.1}). Defaults to None (no constants)

    Returns
    ------------
    bounds : 2-tuple of list
        Lower and upper bounds on the free (non-constant) parameters, in
        the order their elements appear in the circuit string.
    """
    # use None rather than a shared mutable {} default
    if constants is None:
        constants = {}

    lower_bounds, upper_bounds = [], []
    for elem in extract_circuit_elements(circuit):
        raw_element = get_element_from_name(elem)
        for i in range(check_and_eval(raw_element).num_params):
            # parameters held constant are not optimized, so get no bounds
            if elem in constants or elem + f'_{i}' in constants:
                continue
            # parameter index 1 of CPE and La is the alpha, capped at 1
            if raw_element in ['CPE', 'La'] and i == 1:
                upper_bounds.append(1)
            else:
                upper_bounds.append(np.inf)
            lower_bounds.append(0)

    return lower_bounds, upper_bounds

def circuit_fit(frequencies, impedances, circuit, initial_guess, constants=None,
                bounds=None, weight_by_modulus=False, global_opt=False,
                **kwargs):

    """ Main function for fitting an equivalent circuit to data.

    By default, this function uses `scipy.optimize.curve_fit
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html>`_
    to fit the equivalent circuit. This function generally works well for
    simple circuits. However, the final results may be sensitive to
    the initial conditions for more complex circuits. In these cases,
    the `scipy.optimize.basinhopping
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html>`_
    global optimization algorithm can be used to attempt a better fit.

    Parameters
    -----------------
    frequencies : numpy array
        Frequencies

    impedances : array_like of complex
        Impedances (converted with numpy.asarray)

    circuit : string
        String defining the equivalent circuit to be fit

    initial_guess : list of floats
        Initial guesses for the fit parameters

    constants : dictionary, optional
        Parameters and their values to hold constant during fitting
        (e.g. {"R0": 0.1}). Defaults to None (no constants)

    bounds : 2-tuple of array_like, optional
        Lower and upper bounds on parameters. Defaults to bounds on all
        parameters of 0 and np.inf, except the CPE and La alphas
        which have an upper bound of 1

    weight_by_modulus : bool, optional
        Uses the modulus of each data (|Z|) as the weighting factor.
        Standard weighting scheme when experimental variances are unavailable.
        Only applicable when global_opt = False

    global_opt : bool, optional
        If global optimization should be used (uses the basinhopping
        algorithm). Defaults to False

    kwargs :
        Keyword arguments passed to scipy.optimize.curve_fit or
        scipy.optimize.basinhopping

    Returns
    ------------
    p_values : array of floats
        best fit parameters for specified equivalent circuit

    p_errors : array of floats or None
        one standard deviation error estimates for fit parameters

    Notes
    ---------
    With ``global_opt=False`` the errors come from the covariance matrix
    returned by curve_fit and may be inf when curve_fit cannot estimate it.
    With ``global_opt=True`` they are derived from the gradient of the RMSE
    objective at the best local minimum; if that computation fails a
    warning is issued and ``p_errors`` is None.

    """
    if constants is None:
        constants = {}

    Z = np.asarray(impedances)
    # real and imaginary parts are fit together as one real-valued vector
    Z_stacked = np.hstack([Z.real, Z.imag])

    # set upper and lower bounds on a per-element basis
    if bounds is None:
        bounds = set_default_bounds(circuit, constants=constants)

    if not global_opt:
        # maxfev is an evaluation count: the MINPACK-backed 'lm' method
        # (used by curve_fit when no finite bounds are given) expects an int
        kwargs.setdefault('maxfev', int(1e5))
        kwargs.setdefault('ftol', 1e-13)

        # weighting scheme for fitting
        if weight_by_modulus:
            abs_Z = np.abs(Z)
            kwargs['sigma'] = np.hstack([abs_Z, abs_Z])

        popt, pcov = curve_fit(wrapCircuit(circuit, constants), frequencies,
                               Z_stacked, p0=initial_guess, bounds=bounds,
                               **kwargs)

        # Calculate one standard deviation error estimates for fit parameters,
        # defined as the square root of the diagonal of the covariance matrix.
        # https://stackoverflow.com/a/52275674/5144795
        perror = np.sqrt(np.diag(pcov))

    else:
        kwargs.setdefault('seed', 0)

        # build the model once instead of on every objective evaluation
        model = wrapCircuit(circuit, constants)

        def opt_function(x):
            """ RMSE between the model evaluated at parameters ``x`` and the
            data; this is the objective basinhopping minimizes. """
            return rmse(model(frequencies, *x), Z_stacked)

        class BasinhoppingBounds(object):
            """ Accept test rejecting steps outside [xmin, xmax].

            Adapted from the basinhopping documentation
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.basinhopping.html
            """

            def __init__(self, xmin, xmax):
                self.xmin = np.array(xmin)
                self.xmax = np.array(xmax)

            def __call__(self, **kwargs):
                x = kwargs['x_new']
                tmax = bool(np.all(x <= self.xmax))
                tmin = bool(np.all(x >= self.xmin))
                return tmax and tmin

        basinhopping_bounds = BasinhoppingBounds(xmin=bounds[0],
                                                 xmax=bounds[1])
        results = basinhopping(opt_function, x0=initial_guess,
                               accept_test=basinhopping_bounds, **kwargs)
        popt = results.x

        # Calculate perror from the objective's gradient at the minimum
        jac = results.lowest_optimization_result['jac'][np.newaxis]
        try:
            # jacobian -> covariance
            # https://stats.stackexchange.com/q/231868
            pcov = inv(np.dot(jac.T, jac)) * opt_function(popt) ** 2
            # covariance -> perror (one standard deviation
            # error estimates for fit parameters)
            perror = np.sqrt(np.diag(pcov))
        except (ValueError, np.linalg.LinAlgError):
            warnings.warn('Failed to compute perror')
            perror = None

    return popt, perror

def wrapCircuit(circuit, constants):
    """ Bind a circuit string and its constants into a model function.

    Parameters
    ----------
    circuit : str
        Circuit string.
    constants : dict
        Parameters held fixed; passed through to buildCircuit.

    Returns
    -------
    callable
        ``f(frequencies, *parameters)`` returning the real part of the
        circuit impedance followed by the imaginary part, as one array.
    """
    def wrappedCircuit(frequencies, *parameters):
        """ Evaluate the bound circuit at ``parameters`` and stack the real
        and imaginary components of the result.

        Parameters
        ----------
        frequencies : list of floats
        parameters : floats
            Values for the non-constant circuit parameters, in order.

        Returns
        -------
        array of floats
        """
        expression, _ = buildCircuit(circuit, frequencies, *parameters,
                                     constants=constants, eval_string='',
                                     index=0)
        # evaluate the generated expression with the circuit element
        # functions as its namespace
        impedance = eval(expression, circuit_elements)
        return np.hstack([np.real(impedance), np.imag(impedance)])

    return wrappedCircuit

def buildCircuit(circuit, frequencies, *parameters,
                 constants=None, eval_string='', index=0):
    """ recursive function that transforms a circuit, parameters, and
    frequencies into a string that can be evaluated

    Parameters
    ----------
    circuit: str
    frequencies: list/tuple/array of floats
    parameters: list/tuple/array of floats
    constants: dict, optional
        Parameter values held fixed, keyed by element name for
        single-parameter elements and ``<element>_<j>`` for multi-parameter
        elements. Defaults to None (no constants).
    eval_string: str, optional
        Expression built so far; threaded through the recursion.
    index: int, optional
        Index of the next free parameter to consume; threaded through the
        recursion.

    Returns
    -------
    eval_string: str
        Python expression for calculating the resulting fit
    index: int
        Tracks parameter index through recursive calling of the function

    Raises
    ------
    ValueError
        If the circuit string cannot be split into series or parallel
        components.
    """
    # the documented default is None; treat it as "no constants"
    if constants is None:
        constants = {}

    parameters = np.array(parameters).tolist()
    frequencies = np.array(frequencies).tolist()
    circuit = circuit.replace(' ', '')

    def parse_circuit(circuit, parallel=False, series=False):
        """ Splits a circuit string by either dashes (series) or commas
        (parallel) outside of any parenthesis. Removes any leading 'p('
        and trailing ')' when in parallel mode """

        assert parallel != series, \
            'Exactly one of parallel or series must be True'

        def count_parens(string):
            return string.count('('), string.count(')')

        if parallel:
            special = ','
            if circuit.endswith(')') and circuit.startswith('p('):
                circuit = circuit[2:-1]
        if series:
            special = '-'

        split = circuit.split(special)

        result = []
        skipped = set()
        for i, sub_str in enumerate(split):
            if i in skipped:
                continue
            open_parens, closed_parens = count_parens(sub_str)
            if open_parens == closed_parens:
                result.append(sub_str)
                continue
            # the separator was inside parentheses: re-join the following
            # pieces until the parentheses balance again
            j = i
            uneven = True
            while j < len(split) - 1 and uneven:
                sub_str += special + split[j + 1]
                open_parens, closed_parens = count_parens(sub_str)
                uneven = open_parens != closed_parens
                j += 1
                skipped.add(j)
            result.append(sub_str)
        return result

    parallel = parse_circuit(circuit, parallel=True)
    series = parse_circuit(circuit, series=True)

    if len(series) > 1:
        eval_string += "s(["
        split = series
    elif len(parallel) > 1:
        eval_string += "p(["
        split = parallel
    elif series == parallel:
        eval_string += "(["
        split = series
    else:
        # previously this fell through with `split` unbound
        raise ValueError(f'Unable to parse circuit: {circuit}')

    for i, elem in enumerate(split):
        if ',' in elem or '-' in elem:
            # nested sub-circuit: recurse, threading the string and index
            eval_string, index = buildCircuit(elem, frequencies,
                                              *parameters,
                                              constants=constants,
                                              eval_string=eval_string,
                                              index=index)
        else:
            raw_elem = get_element_from_name(elem)
            elem_number = check_and_eval(raw_elem).num_params
            param_list = []
            for j in range(elem_number):
                # multi-parameter elements use '<name>_<j>' constant keys
                if elem_number > 1:
                    current_elem = f'{elem}_{j}'
                else:
                    current_elem = elem

                if current_elem in constants:
                    # convert numpy scalars to plain Python numbers, as done
                    # for `parameters` above: NumPy >= 2 reprs such as
                    # 'np.float64(0.1)' may not resolve when this string
                    # is eval'd
                    param_list.append(
                        np.array(constants[current_elem]).tolist())
                else:
                    param_list.append(parameters[index])
                    index += 1

            eval_string += (raw_elem + '(' + str(param_list) + ',' +
                            str(frequencies) + ')')

        eval_string += '])' if i == len(split) - 1 else ','

    return eval_string, index

def extract_circuit_elements(circuit):
    """ Split a circuit string into its element names.

    The structural characters 'p', '(', ')', ',' and '-' are dropped. The
    remaining characters are cut into names, each name ending at a digit
    that is not followed by another digit. Whatever remains after the last
    such cut forms a final entry, so an empty circuit yields ``['']``.

    Parameters
    ----------
    circuit : str
        Circuit string.

    Returns
    -------
    extracted_elements : list
        Element names in order of appearance, e.g. ``['R0', 'C1']`` for
        ``'R0-C1'``.

    """
    chars = [c for c in circuit if c not in 'p(),-']
    last = len(chars) - 1
    names = []
    name = ''
    for pos, c in enumerate(chars):
        name += c
        # a digit closes the name unless the next character is a digit too;
        # at the final position the lookahead is clamped to the char itself
        if c in ints and chars[min(pos + 1, last)] not in ints:
            names.append(name)
            name = ''
    names.append(name)
    return names

def calculateCircuitLength(circuit):
    """ Count the parameters required by a circuit string.

    Each element found in the circuit contributes its ``num_params``.

    Parameters
    ----------
    circuit : str
        Circuit string.

    Returns
    -------
    length : int
        Total number of parameters; 0 for an empty or falsy circuit.

    """
    if not circuit:
        return 0
    return sum(check_and_eval(get_element_from_name(elem)).num_params
               for elem in extract_circuit_elements(circuit))

def check_and_eval(element):
    """ Checks if an element is valid, then evaluates it.

    Parameters
    ----------
    element : str
        Circuit element.

    Raises
    ------
    ValueError
        Raised if an element is not in the list of allowed elements.

    Returns
    -------
    The object registered under ``element`` in ``circuit_elements``.

    """
    allowed_elements = circuit_elements.keys()
    if element not in allowed_elements:
        raise ValueError(f'{element} not in ' +
                         f'allowed elements ({allowed_elements})')
    # `element` is a validated key, so a plain lookup returns the same
    # object as eval(element, circuit_elements) did, without invoking eval
    # or having it insert '__builtins__' into the shared dict
    return circuit_elements[element]