42 #ifndef TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP 43 #define TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP 45 #include "TpetraCore_config.h" 46 #include "Tpetra_CrsMatrix.hpp" 47 #include "Teuchos_RCP.hpp" 61 template<
class SparseMatrixType,
62 class MultiVectorType>
63 void extractBlockDiagonal(
const SparseMatrixType& A, MultiVectorType & diagonal) {
64 using local_map_type =
typename SparseMatrixType::map_type::local_map_type;
65 using SC =
typename MultiVectorType::scalar_type;
66 using LO =
typename SparseMatrixType::local_ordinal_type;
67 using KCRS =
typename SparseMatrixType::local_matrix_device_type;
68 using lno_view_t =
typename KCRS::StaticCrsGraphType::row_map_type::const_type;
69 using lno_nnz_view_t =
typename KCRS::StaticCrsGraphType::entries_type::const_type;
70 using scalar_view_t =
typename KCRS::values_type::const_type;
71 using local_mv_type =
typename MultiVectorType::dual_view_type::t_dev;
72 using range_type = Kokkos::RangePolicy<typename SparseMatrixType::node_type::execution_space, LO>;
76 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*diagonal.getMap()),
77 std::runtime_error,
"Tpetra::Details::extractBlockDiagonal was given incompatible maps");
80 LO numrows = diagonal.getLocalLength();
81 LO blocksize = diagonal.getNumVectors();
82 SC
ZERO = Teuchos::ScalarTraits<typename MultiVectorType::scalar_type>::zero();
85 local_map_type rowmap = A.getRowMap()->getLocalMap();
86 local_map_type colmap = A.getRowMap()->getLocalMap();
87 local_mv_type diag = diagonal.getLocalViewDevice(Access::OverwriteAll);
88 const KCRS Amat = A.getLocalMatrixDevice();
89 lno_view_t Arowptr = Amat.graph.row_map;
90 lno_nnz_view_t Acolind = Amat.graph.entries;
91 scalar_view_t Avals = Amat.values;
93 Kokkos::parallel_for(
"Tpetra::extractBlockDiagonal",range_type(0,numrows),KOKKOS_LAMBDA(
const LO i){
94 LO diag_col = colmap.getLocalElement(rowmap.getGlobalElement(i));
95 LO blockStart = diag_col - (diag_col % blocksize);
96 LO blockStop = blockStart + blocksize;
97 for(LO k=0; k<blocksize; k++)
100 for (
size_t k = Arowptr(i); k < Arowptr(i+1); k++) {
102 if (blockStart <= col && col < blockStop) {
103 diag(i,col-blockStart) = Avals(k);
112 #endif // TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP Namespace Tpetra contains the class and methods constituting the Tpetra library.
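// Doxygen cross-reference tooltips captured with this listing (kept as comments
// so they are not parsed as code):
//   static bool debug()
//     Whether Tpetra is in debug mode.
//   Implementation details of Tpetra.
//   Replace old values with zero.
//   Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.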