mod_matrix_elements.f08 Source File


Contents


Source Code

module mod_matrix_elements
  use mod_global_variables, only: dp, NaN
  use mod_matrix_element_node, only: matrix_element_node_t, new_matrix_element_node
  use mod_basis_functions, only: basis_function
  use mod_state_vector, only: state_vector_t
  use mod_state_vector_component, only: sv_component_t
  use mod_get_indices, only: get_index
  use mod_logging, only: logger, str
  implicit none

  private

  !> Singly linked list of matrix element nodes. The list keeps a pointer to a
  !! state vector, whose component names are used to turn the components passed
  !! to <tt>add</tt> into a position. All components are default-initialised so
  !! that an instance which was never passed through <tt>new_matrix_elements</tt>
  !! still has a well-defined (empty) state when it is queried or deleted.
  type, public :: matrix_elements_t
    !> number of nodes currently stored in the list
    integer, private :: nb_elements = 0
    !> first node of the list, not associated when the list is empty
    type(matrix_element_node_t), pointer, private :: head => null()
    !> last node of the list, new nodes are appended behind it
    type(matrix_element_node_t), pointer, private :: tail => null()
    !> state vector used to look up the position of the components
    type(state_vector_t), pointer, private :: state_vector => null()

  contains

    !> appends a node to the end of the list
    procedure, public :: add => add_node
    !> returns a pointer to the node at a given one-based index
    procedure, public :: get_node
    !> returns the elements of all nodes as an array
    procedure, public :: get_elements
    !> returns the positions of all nodes as an array
    procedure, public :: get_positions
    !> returns the number of nodes in the list
    procedure, public :: get_nb_elements
    !> deallocates all nodes and resets the list
    procedure, public :: delete
    !> increases the node counter by one
    procedure, private :: increment_nb_elements
  end type matrix_elements_t

  public ::  new_matrix_elements

contains

  !> Constructor: returns an empty list of matrix elements that points to the
  !! given state vector. The state vector is referenced, not copied.
  function new_matrix_elements(state_vector) result(elements)
    !> state vector the new list will point to
    type(state_vector_t), target, intent(in) :: state_vector
    !> the newly created, empty list
    type(matrix_elements_t) :: elements

    elements%state_vector => state_vector
    ! no nodes yet, so both ends of the list are disassociated
    nullify(elements%head, elements%tail)
    elements%nb_elements = 0
  end function new_matrix_elements


  !> Appends a new node to the end of the list. Nothing is added when one of the
  !! component names does not resolve to a positive index in the names of the
  !! state vector, or when the element, after conversion to complex, is flagged
  !! as NaN or infinite.
  subroutine add_node(this, element, sv_comp1, sv_comp2, s1do, s2do)
    class(matrix_elements_t), intent(inout) :: this
    !> the matrix element, converted to complex before it is stored
    class(*), intent(in) :: element
    !> first state vector component
    type(sv_component_t), intent(in) :: sv_comp1
    !> second state vector component
    type(sv_component_t), intent(in) :: sv_comp2
    !> spline 1 derivative order, 1 = first derivative, 2 = second derivative, etc.
    integer, intent(in), optional :: s1do
    !> spline 2 derivative order, 1 = first derivative, 2 = second derivative, etc.
    integer, intent(in), optional :: s2do

    complex(dp) :: converted
    integer :: indices(2)
    procedure(basis_function), pointer :: basis1, basis2
    type(matrix_element_node_t), pointer :: fresh_node

    ! look up both components in the state vector, bail out if either is invalid
    indices(1) = get_index( &
      name=sv_comp1%get_name(), array=this%state_vector%get_names() &
    )
    indices(2) = get_index( &
      name=sv_comp2%get_name(), array=this%state_vector%get_names() &
    )
    if (.not. is_valid_position(indices)) return

    ! sanity checks on the value itself, these report through the logger
    converted = cast_node_element_to_complex(element)
    if (is_NaN_element(converted, sv_comp1, sv_comp2)) return
    if (is_inf_element(converted, sv_comp1, sv_comp2)) return

    ! the optional derivative orders are handed on as they are
    call sv_comp1%get_spline_function(spline_order=s1do, spline_func=basis1)
    call sv_comp2%get_spline_function(spline_order=s2do, spline_func=basis2)

    allocate( &
      fresh_node, source=new_matrix_element_node(converted, indices, basis1, basis2) &
    )
    if (associated(this%head)) then
      ! list already has nodes: hook the new one behind the current tail
      this%tail%next => fresh_node
    else
      ! empty list: the new node becomes the head
      this%head => fresh_node
    end if
    this%tail => fresh_node
    call this%increment_nb_elements()
  end subroutine add_node


  !> Returns a pointer to the node at position <tt>inode</tt>, counted from the
  !! head of the list starting at 1. An index outside of the list is reported
  !! through the logger and the result is left disassociated.
  function get_node(this, inode) result(node)
    class(matrix_elements_t), intent(in) :: this
    !> one-based index of the requested node
    integer, intent(in) :: inode
    !> pointer to the requested node
    type(matrix_element_node_t), pointer :: node
    integer :: hop

    nullify(node)
    if (inode >= 1 .and. inode <= this%nb_elements) then
      ! start at the head and follow the links inode - 1 times
      node => this%head
      do hop = 2, inode
        node => node%next
      end do
    else
      call logger%error("get_node: inode out of range: " // str(inode))
    end if
  end function get_node


  !> Collects the element of every node, in list order, into a complex array
  !! with one entry per node.
  function get_elements(this) result(elements)
    class(matrix_elements_t), intent(in) :: this
    !> elements of all nodes, head first
    complex(dp) :: elements(this%nb_elements)
    type(matrix_element_node_t), pointer :: cursor
    integer :: filled

    cursor => this%head
    filled = 0
    do while (filled < this%nb_elements)
      filled = filled + 1
      elements(filled) = cursor%get_element()
      cursor => cursor%next
    end do
  end function get_elements


  !> Collects the position of every node, in list order, into an integer array
  !! in which row <tt>i</tt> holds the two indices of node <tt>i</tt>.
  function get_positions(this) result(positions)
    class(matrix_elements_t), intent(in) :: this
    !> positions of all nodes, one row per node
    integer :: positions(this%nb_elements, 2)
    type(matrix_element_node_t), pointer :: walker
    integer :: row

    walker => this%head
    row = 1
    do
      if (row > this%nb_elements) exit
      positions(row, 1:2) = walker%get_position()
      walker => walker%next
      row = row + 1
    end do
  end function get_positions


  !> Returns the number of nodes currently stored in the list.
  pure function get_nb_elements(this) result(total)
    class(matrix_elements_t), intent(in) :: this
    !> current value of the node counter
    integer :: total

    total = this%nb_elements
  end function get_nb_elements


  !> Converts an unlimited polymorphic element to a complex number of kind
  !! <tt>dp</tt>. Default integers and reals of kind <tt>dp</tt> end up in the real
  !! part with a zero imaginary part, complex values of kind <tt>dp</tt> are
  !! passed through. Any other type yields a complex number with NaN in both parts.
  pure function cast_node_element_to_complex(element) result(node_element)
    !> the element to convert
    class(*), intent(in) :: element
    !> the element as a complex number
    complex(dp) :: node_element

    select type(item => element)
      type is (integer)
        node_element = cmplx(real(item, kind=dp), 0.0_dp, kind=dp)
      type is (real(dp))
        ! the imaginary part defaults to zero when it is omitted
        node_element = cmplx(item, kind=dp)
      type is (complex(dp))
        node_element = item
      class default
        ! unsupported type: flag the result as not-a-number
        node_element = cmplx(NaN, NaN, kind=dp)
    end select
  end function cast_node_element_to_complex


  !> Returns <tt>.true.</tt> if every entry of the given position is strictly
  !! positive, which is also the case for a zero-sized array.
  pure function is_valid_position(position) result(valid)
    !> indices to check
    integer, intent(in) :: position(:)
    logical :: valid

    ! valid means that not a single index is zero or negative
    valid = .not. any(position <= 0)
  end function is_valid_position


  !> Checks whether the given element is NaN according to <tt>is_NaN</tt>.
  !! If so, an error naming the two components is sent to the logger.
  function is_NaN_element(element, sv_comp1, sv_comp2) result(found_nan)
    use mod_check_values, only: is_NaN

    !> the element to check
    complex(dp), intent(in) :: element
    !> first state vector component, only used for the message
    type(sv_component_t), intent(in) :: sv_comp1
    !> second state vector component, only used for the message
    type(sv_component_t), intent(in) :: sv_comp2
    logical :: found_nan
    character(len=:), allocatable :: component_pair

    found_nan = is_NaN(element)
    if (.not. found_nan) return

    component_pair = trim(sv_comp1%get_name()) // ", " // trim(sv_comp2%get_name())
    call logger%error( &
      "NaN matrix element encountered for [" // component_pair // "]" &
    )
  end function is_NaN_element


  !> Checks whether the given element is infinite according to
  !! <tt>is_infinite</tt>. If so, an error naming the two components is sent
  !! to the logger.
  function is_inf_element(element, sv_comp1, sv_comp2) result(found_inf)
    use mod_check_values, only: is_infinite

    !> the element to check
    complex(dp), intent(in) :: element
    !> first state vector component, only used for the message
    type(sv_component_t), intent(in) :: sv_comp1
    !> second state vector component, only used for the message
    type(sv_component_t), intent(in) :: sv_comp2
    logical :: found_inf
    character(len=:), allocatable :: component_pair

    found_inf = is_infinite(element)
    if (.not. found_inf) return

    component_pair = trim(sv_comp1%get_name()) // ", " // trim(sv_comp2%get_name())
    call logger%error( &
      "Infinity encountered in matrix element for [" // component_pair // "]" &
    )
  end function is_inf_element


  !> Increases the stored number of nodes by one.
  pure subroutine increment_nb_elements(this)
    class(matrix_elements_t), intent(inout) :: this

    associate (counter => this%nb_elements)
      counter = counter + 1
    end associate
  end subroutine increment_nb_elements


  !> Deallocates every node of the list and resets the list to an empty state:
  !! head, tail and the state vector pointer are nullified and the node counter
  !! is set to zero. The reset is done unconditionally, so a list that never
  !! held a node no longer keeps pointing to its state vector after this call.
  pure subroutine delete(this)
    class(matrix_elements_t), intent(inout) :: this
    type(matrix_element_node_t), pointer :: current_node, next_node

    current_node => this%head
    do while (associated(current_node))
      ! grab the successor before the node is cleaned up and deallocated
      next_node => current_node%next
      call current_node%delete()
      deallocate(current_node)
      current_node => next_node
    end do

    ! the nodes are gone (or were never there), drop every remaining reference
    nullify(this%head)
    nullify(this%tail)
    nullify(this%state_vector)
    this%nb_elements = 0
  end subroutine delete

end module mod_matrix_elements