Source code for felupe.constitution.jax._material

# -*- coding: utf-8 -*-
"""
This file is part of FElupe.

FElupe is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

FElupe is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FElupe.  If not, see <http://www.gnu.org/licenses/>.
"""

import warnings

import jax
import numpy as np

from .._material import Material as MaterialDefault
from ._helpers import vmap2


class Material(MaterialDefault):
    r"""A material definition with a given function for the partial derivative of the
    strain energy function w.r.t. the deformation gradient tensor with Automatic
    Differentiation provided by :mod:`jax`.

    Parameters
    ----------
    fun : callable
        A gradient of the strain energy density function w.r.t. the deformation
        gradient tensor :math:`\boldsymbol{F}`. Function signature must be
        ``fun = lambda F, **kwargs: P`` for functions without state variables and
        ``fun = lambda F, statevars, **kwargs: [P, statevars_new]`` for functions
        with state variables. It is important to use only differentiable
        math-functions from :mod:`jax`.
    nstatevars : int, optional
        Number of state variables (default is 0).
    jit : bool, optional
        A flag to invoke just-in-time compilation (default is True).
    parallel : bool, optional
        A flag to invoke parallel function evaluations (default is False). If True,
        the quadrature points are executed in parallel. The number of devices must be
        greater than or equal to the number of quadrature points per cell.
    jacobian : callable or None, optional
        A callable for the Jacobian. Default is None, where :func:`jax.jacobian` is
        used. This may be used to switch to forward-mode differentiation
        :func:`jax.jacfwd`.
    **kwargs : dict, optional
        Optional keyword-arguments for the gradient of the strain energy density
        function.

    Notes
    -----
    The gradient of the strain energy density function
    :math:`\frac{\partial \psi}{\partial \boldsymbol{F}}` must be given in terms of
    the deformation gradient tensor :math:`\boldsymbol{F}`.

    .. warning::
        It is important to use only differentiable math-functions from :mod:`jax`!

    Take this code-block as template

    .. code-block::

        import felupe as fem
        import felupe.constitution.jax as mat
        import jax.numpy as jnp

        def neo_hooke(F, mu):
            "First Piola-Kirchhoff stress of the Neo-Hookean material formulation."

            C = F.T @ F
            Cu = jnp.linalg.det(C) ** (-1/3) * C
            dev = lambda C: C - jnp.trace(C) / 3 * jnp.eye(3)

            return mu * F @ dev(Cu) @ jnp.linalg.inv(C)

        umat = mat.Material(neo_hooke, mu=1)

    and this code-block for material formulations with state variables:

    .. code-block::

        import felupe as fem
        import felupe.constitution.jax as mat
        import jax.numpy as jnp

        def viscoelastic(F, Cin, mu, eta, dtime):
            "Finite strain viscoelastic material formulation."

            # unimodular part of the right Cauchy-Green deformation tensor
            C = F.T @ F
            Cu = jnp.linalg.det(C) ** (-1 / 3) * C

            # update of state variables by evolution equation
            from_triu = lambda C: C[jnp.array([[0, 1, 2], [1, 3, 4], [2, 4, 5]])]
            Ci = from_triu(Cin) + mu / eta * dtime * Cu
            Ci = jnp.linalg.det(Ci) ** (-1 / 3) * Ci

            # second Piola-Kirchhoff stress tensor
            dev = lambda C: C - jnp.trace(C) / 3 * jnp.eye(3)
            S = mu * dev(Cu @ jnp.linalg.inv(Ci)) @ jnp.linalg.inv(C)

            # first Piola-Kirchhoff stress tensor and state variable
            i, j = jnp.triu_indices(3)
            to_triu = lambda C: C[i, j]
            return F @ S, to_triu(Ci)

        umat = mat.Material(viscoelastic, mu=1, eta=1, dtime=1, nstatevars=6)

    .. note::
        See the `documentation of JAX <https://jax.readthedocs.io>`_ for further
        details. JAX uses single-precision (32bit) data types by default. This
        requires to relax the tolerance of :func:`~felupe.newtonrhapson` to
        ``tol=1e-4``. If required, JAX may be enforced to use double-precision at
        startup with ``jax.config.update("jax_enable_x64", True)``.

    Examples
    --------
    View force-stretch curves on elementary incompressible deformations.

    .. pyvista-plot::
        :context:

        >>> import felupe as fem
        >>> import felupe.constitution.jax as mat
        >>> import jax.numpy as jnp
        >>>
        >>> def neo_hooke(F, mu):
        ...     "First Piola-Kirchhoff stress of the Neo-Hookean material formulation."
        ...
        ...     C = F.T @ F
        ...     Cu = jnp.linalg.det(C) ** (-1/3) * C
        ...     dev = lambda C: C - jnp.trace(C) / 3 * jnp.eye(3)
        ...
        ...     return mu * F @ dev(Cu) @ jnp.linalg.inv(C)
        >>>
        >>> umat = mat.Material(neo_hooke, mu=1)
        >>> ax = umat.plot(incompressible=True)

    .. pyvista-plot::
        :include-source: False
        :context:
        :force_static:

        >>> import pyvista as pv
        >>>
        >>> fig = ax.get_figure()
        >>> chart = pv.ChartMPL(fig)
        >>> chart.show()

    """

    def __init__(
        self, fun, nstatevars=0, jit=True, parallel=False, jacobian=None, **kwargs
    ):
        # With state variables the user function returns ``[P, statevars_new]``;
        # the second item is passed through the Jacobian as auxiliary output.
        has_aux = nstatevars > 0
        self.fun = fun

        if jacobian is None:
            jacobian = jax.jacobian

        # Keyword-arguments attached to ``fun`` act as defaults; explicitly
        # given keyword-arguments take precedence.
        keyword_args = kwargs
        if hasattr(fun, "kwargs"):
            keyword_args = {**fun.kwargs, **keyword_args}

        super().__init__(
            stress=self._stress,
            elasticity=self._elasticity,
            nstatevars=nstatevars,
            **keyword_args,
        )

        # Axes handed to ``vmap2`` for the two mapping levels. With state
        # variables each entry is a pair (presumably one axis for ``F`` and one
        # for the state variables -- confirm against ``vmap2``).
        if has_aux:
            in_axes = out_axes_grad = [(2, 1), (3, 2)]
            out_axes_hess = [(4, 1), (5, 2)]
        else:
            in_axes = out_axes_grad = [2, 3]
            out_axes_hess = [4, 5]

        methods = [jax.vmap, jax.vmap]
        if parallel:
            methods[0] = jax.pmap  # apply on quadrature-points
            jit = False  # pmap uses jit

        self._grad = vmap2(
            self.fun, in_axes=in_axes, out_axes=out_axes_grad, methods=methods
        )
        self._hess = vmap2(
            jacobian(self.fun, has_aux=has_aux),
            in_axes=in_axes,
            out_axes=out_axes_hess,
            methods=methods,
        )

        if jit:
            self._grad = jax.jit(self._grad)
            self._hess = jax.jit(self._hess)

    def _stress(self, x, **kwargs):
        """Evaluate the mapped user function and return a list with the result
        as a NumPy array and the updated state variables (a NumPy array, or
        ``None`` if the material has no state variables).

        ``x[0]`` is the deformation gradient; ``x[1]`` holds the state variables
        and is only read if ``nstatevars > 0``.
        """
        F = x[0]

        if self.nstatevars > 0:
            statevars = x[1]
            dWdF, statevars_new = self._grad(F, statevars, **kwargs)
            statevars_new = np.array(statevars_new)
        else:
            dWdF = self._grad(F, **kwargs)
            statevars_new = None

        return [np.array(dWdF), statevars_new]

    def _elasticity(self, x, **kwargs):
        """Evaluate the mapped Jacobian of the user function and return it as a
        NumPy array inside a one-item list.

        ``x[0]`` is the deformation gradient; ``x[1]`` holds the state variables
        and is only read if ``nstatevars > 0``.
        """
        F = x[0]

        if self.nstatevars > 0:
            statevars = x[1]
            # The auxiliary output (updated state variables) is not needed here.
            d2WdFdF, _ = self._hess(F, statevars, **kwargs)
        else:
            d2WdFdF = self._hess(F, **kwargs)

        return [np.array(d2WdFdF)]