Source code for felupe.constitution.jax._helpers

# -*- coding: utf-8 -*-
"""
This file is part of FElupe.

FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FElupe.  If not, see <http://www.gnu.org/licenses/>.
"""
import inspect
from functools import wraps

import jax
import jax.numpy as jnp
from jax.numpy.linalg import det


[docs] def vmap(fun, in_axes=0, out_axes=0, method=jax.vmap, **kwargs): """Vectorizing map. Creates a function which maps ``fun`` over argument axes. This decorator treats all non-specified arguments and keyword-arguments as static. See Also -------- jax.vmap : Vectorizing map. Creates a function which maps ``fun`` over argument axes. """ @wraps(fun) def vmap_with_static_kwargs(*args, **keywordargs): # sorted list of all parameter keys, including kwargs with default values sig = inspect.signature(fun) keys = [ key for key, value in sig.parameters.items() if not (key in ["args", "kwargs"] and value.default == inspect._empty) ] if not ( "kwargs" in sig.parameters.keys() and sig.parameters["kwargs"].default == inspect._empty ): # check if unexpected keyword-argument is given for key in keywordargs.keys(): if key not in keys: raise TypeError( f"{fun.__name__}() got an unexpected keyword argument '{key}'" ) # dict with default values for all parameters parameters = dict( [(key, value.default) for key, value in sig.parameters.items()] ) # merge dict of default values with custom keyword arguments items = {**parameters, **keywordargs} # create sorted list of values of keyword-arguments, including default kwargs keyword_args = [items[key] for key in keys[len(args) :]] # don't map non-given arguments and keyword-arguments if not hasattr(in_axes, "__len__"): in_axes_tuple = (in_axes,) else: in_axes_tuple = in_axes static_argnums = len(args) + len(keyword_args) - len(in_axes_tuple) in_axes_new = (*in_axes_tuple, *([None] * static_argnums)) vfun = method(fun, in_axes=in_axes_new, out_axes=out_axes, **kwargs) return vfun(*args, *keyword_args) return vmap_with_static_kwargs
def vmap2(fun, in_axes=[0, 0], out_axes=[0, 0], methods=[jax.vmap, jax.vmap], **kwargs): "Nested vectorizing map." return vmap( vmap( fun, in_axes=in_axes[0], out_axes=out_axes[0], method=methods[0], **kwargs ), in_axes=in_axes[1], out_axes=out_axes[1], method=methods[1], **kwargs, ) def as_total_lagrange(fun): @wraps(fun) def evaluate(F, *args, **kwargs): i, j = jnp.triu_indices(3) C_triu = jnp.einsum("ia,ia->a", F[:, i], F[:, j]) C = C_triu[jnp.array([[0, 1, 2], [1, 3, 4], [2, 4, 5]])] return fun(C, *args, **kwargs) return evaluate def isochoric_volumetric_split(fun): """Apply the material formulation only on the isochoric part of the multiplicative split of the deformation gradient.""" @wraps(fun) def apply_iso(C, *args, **kwargs): return fun(det(C) ** (-1 / 3) * C, *args, **kwargs) return apply_iso