// NOTE(review): this listing looks like an extracted source view: each line
// carries an embedded original line number (e.g. "46", "70") and several
// original lines (function signatures, braces, some checks) are absent, so it
// is not compilable as shown. Comments below describe only the visible code.
// NOTE(review): Xpetra_BlockedCrsMatrix.hpp is included twice on the next
// line; the second include is redundant.
46 #ifndef MUELU_SCHURCOMPLEMENTFACTORY_DEF_HPP_ 47 #define MUELU_SCHURCOMPLEMENTFACTORY_DEF_HPP_ 49 #include <Xpetra_BlockedCrsMatrix.hpp> 50 #include <Xpetra_MultiVectorFactory.hpp> 51 #include <Xpetra_VectorFactory.hpp> 52 #include <Xpetra_MatrixFactory.hpp> 53 #include <Xpetra_Matrix.hpp> 54 #include <Xpetra_MatrixMatrix.hpp> 55 #include <Xpetra_TripleMatrixMultiply.hpp> 56 #include <Xpetra_CrsMatrixWrap.hpp> 57 #include <Xpetra_BlockedCrsMatrix.hpp> 58 #include <Xpetra_CrsMatrix.hpp> 62 #include "MueLu_Utilities.hpp" 63 #include "MueLu_SchurComplementFactory.hpp" 68 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// GetValidParameterList: builds the list of parameters this factory accepts,
// together with their defaults. (The function signature is not part of this
// listing.)
//   "A"     - generating factory of the 2x2 blocked operator; defaults to
//             NoFactory::getRCP().
//   "Ainv"  - generating factory of the inverse matrix used in the Schur
//             complement; defaults to null.
//   "omega" - scaling parameter in S = A(1,1) - 1/omega A(1,0) Ainv A(0,1);
//             defaults to Teuchos::ScalarTraits<SC>::one().
70 RCP<ParameterList> validParamList = rcp(
new ParameterList());
72 const SC one = Teuchos::ScalarTraits<SC>::one();
74 validParamList->set<RCP<const FactoryBase> >(
"A" ,
NoFactory::getRCP(),
"Generating factory of the matrix A used for building Schur complement (must be a 2x2 blocked operator)");
75 validParamList->set<RCP<const FactoryBase> >(
"Ainv" , Teuchos::null,
"Generating factory of the inverse matrix used in the Schur complement");
77 validParamList->set<SC> (
"omega", one,
"Scaling parameter in S = A(1,1) - 1/omega A(1,0) Ainv A(0,1)");
79 return validParamList;
82 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// DeclareInput: declares the blocked operator "A" as an input of this factory
// on currentLevel. (The function signature is not part of this listing.)
84 Input(currentLevel,
"A");
// Look up the factory configured for "Ainv". The statements that use AinvFact
// are not part of this listing; presumably they request "Ainv" from that
// factory. Confirm against the full source.
87 RCP<const FactoryBase> AinvFact = GetFactory(
"Ainv");
91 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// Build: reads the 2x2 blocked operator "A" and the matrix "Ainv" from
// currentLevel, computes the Schur complement S via ComputeSchurComplement,
// and stores S on the level under the name "A". (The function signature is
// not part of this listing.)
95 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel,
"A");
// bA is null unless A is actually a BlockedCrsMatrix.
96 RCP<BlockedCrsMatrix> bA = rcp_dynamic_cast<BlockedCrsMatrix>(A);
// Messages of the two input checks. The check statements themselves are not
// part of this listing. Per the messages, A must be a BlockedCrsMatrix with
// 2x2 blocks.
99 "MueLu::SchurComplementFactory::Build: input matrix A is not of type BlockedCrsMatrix!");
101 "MueLu::SchurComplementFactory::Build: input matrix A is a " << bA->Rows() <<
"x" << bA->Cols() <<
" block matrix. We expect a 2x2 blocked operator.");
// Ainv is fetched from the level entry produced by the factory registered
// for "Ainv".
104 RCP<Matrix> Ainv = currentLevel.Get<RCP<Matrix> >(
"Ainv", this->GetFactory(
"Ainv").get());
105 RCP<Matrix> S = ComputeSchurComplement(bA, Ainv);
// Report the global dimensions of S at Statistics1 verbosity.
107 GetOStream(
Statistics1) <<
"S has " << S->getGlobalNumRows() <<
"x" << S->getGlobalNumCols() <<
" rows and columns." << std::endl;
// Publish the Schur complement on this level as "A".
111 Set(currentLevel,
"A", S);
114 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
// ComputeSchurComplement: given the 2x2 blocked operator bA and the matrix
// Ainv, forms S = A(1,1) - 1/omega A(1,0) Ainv A(0,1), where omega is read
// from this factory's parameter list. The rest of the signature and several
// branch lines (conditions, else lines, braces) are not part of this listing.
115 RCP<Xpetra::Matrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>>
118 using STS = Teuchos::ScalarTraits<SC>;
119 const SC zero = STS::zero(), one = STS::one();
// Off-diagonal blocks and lower-right block of the 2x2 operator.
121 RCP<Matrix> A01 = bA->getMatrix(0,1);
122 RCP<Matrix> A10 = bA->getMatrix(1,0);
123 RCP<Matrix> A11 = bA->getMatrix(1,1);
// isBlocked is true when A01 is itself a (nested) BlockedCrsMatrix. It
// presumably selects between the plain and the block-wise product paths below.
125 RCP<BlockedCrsMatrix> bA01 = Teuchos::rcp_dynamic_cast<BlockedCrsMatrix>(A01);
126 const bool isBlocked = (bA01 == Teuchos::null ? false :
true);
128 const ParameterList& pL = GetParameterList();
129 const SC omega = pL.get<
Scalar>(
"omega");
// Message of the omega check (the check statement itself is not part of this
// listing). omega is used as a divisor below.
132 "MueLu::SchurComplementFactory::Build: Scaling parameter omega must not be zero to avoid division by zero.");
134 RCP<Matrix> S = Teuchos::null;
135 RCP<Matrix> D = Teuchos::null;
// When both off-diagonal blocks exist, form D = A(1,0) * (-1/omega Ainv) * A(0,1),
// so that S = A(1,1) + D below matches the formula above.
138 if(A01.is_null() ==
false && A10.is_null() ==
false) {
// NOTE(review): this scales the matrix object Ainv refers to in place. Unless
// the scaling is undone in lines not part of this listing, the Ainv that
// Build fetched from the Level is modified as a side effect. Confirm intended.
140 Ainv->scale(Teuchos::as<Scalar>(-one/omega));
// Options passed to the two Multiply calls below: "compute global constants" = true.
144 RCP<ParameterList> myparams = rcp(
new ParameterList);
145 myparams->set(
"compute global constants",
true);
// Plain path (presumably isBlocked == false): C = Ainv * A01, then D = A10 * C.
// Ainv's domain map must match A01's range map for the first product.
148 TEUCHOS_TEST_FOR_EXCEPTION(A01->getRangeMap()->isSameAs(*(Ainv->getDomainMap())) ==
false,
Exceptions::RuntimeError,
149 "MueLu::SchurComplementFactory::Build: RangeMap of A01 and domain map of Ainv are not the same.");
150 RCP<Matrix> C = MatrixMatrix::Multiply(*Ainv,
false, *A01,
false, GetOStream(
Statistics2),
true,
true, std::string(
"SchurComplementFactory"), myparams);
// Message of the map compatibility check for D = A10 * C (the check statement
// is not part of this listing).
154 "MueLu::SchurComplementFactory::Build: RangeMap of A10 and domain map A01 are not the same.");
155 D = MatrixMatrix::Multiply(*A10,
false, *C,
false, GetOStream(
Statistics2),
true,
true, std::string(
"SchurComplementFactory"), myparams);
// Block-wise path (presumably isBlocked == true). A10 and Ainv are cast to
// BlockedCrsMatrix; per the messages, Ainv must be castable and the block
// structures must be compatible. Products are formed block-wise as
// C = bAinv * bA01 and D = bA10 * C.
159 auto bA10 = Teuchos::rcp_dynamic_cast<BlockedCrsMatrix>(A10);
160 auto bAinv = Teuchos::rcp_dynamic_cast<BlockedCrsMatrix>(Ainv);
162 "MueLu::SchurComplementFactory::Build: Casting Ainv to BlockedCrsMatrix not possible.");
166 "MueLu::SchurComplementFactory::Build: Block rows and cols of bA01 and bAinv are not compatible.");
167 RCP<BlockedCrsMatrix> C = MatrixMatrix::TwoMatrixMultiplyBlock(*bAinv,
false, *bA01,
false, GetOStream(
Statistics2));
171 "MueLu::SchurComplementFactory::Build: Block rows and cols of bA10 and bA01 are not compatible.");
172 D = MatrixMatrix::TwoMatrixMultiplyBlock(*bA10,
false, *C,
false, GetOStream(
Statistics2));
// S = A11 + D. Both scaling factors are one because the -1/omega factor is
// already contained in D via the scaled Ainv. The messages that follow belong
// to map checks comparing A11 with S.
174 if (!A11.is_null()) {
175 MatrixMatrix::TwoMatrixAdd(*A11,
false, one, *D,
false, one, S, GetOStream(
Statistics2));
179 "MueLu::SchurComplementFactory::Build: RangeMap of A11 and S are not the same.");
181 "MueLu::SchurComplementFactory::Build: DomainMap of A11 and S are not the same.");
// Presumably the A11 == null case (the else line is not part of this listing):
// S is a copy of D.
184 S = MatrixFactory::BuildCopy(D);
// Presumably the case where A01 or A10 is missing, so S reduces to A11 (the
// enclosing else line is not part of this listing).
188 if (!A11.is_null()) {
189 S = MatrixFactory::BuildCopy(A11);
// NOTE(review): the two statements below dereference A11. If they form the
// else-branch of the A11 null check above (the else line is not part of this
// listing), A11 is null here and these calls will fail. The maps would have
// to come from elsewhere, e.g. block row/column 1 of bA.
// The 10 is presumably a per-row allocation hint for the empty matrix.
191 S = MatrixFactory::Build(A11->getRowMap(), 10 );
192 S->fillComplete(A11->getDomainMap(),A11->getRangeMap());
// If S is a 1x1 BlockedCrsMatrix, take bS->getCrsMatrix(). How temp is used
// afterwards is not part of this listing.
203 RCP<BlockedCrsMatrix> bS = Teuchos::rcp_dynamic_cast<BlockedCrsMatrix>(S);
205 if (bS != Teuchos::null && bS->Rows() == 1 && bS->Cols() == 1) {
206 RCP<Matrix> temp = bS->getCrsMatrix();
Exception indicating invalid cast attempted.
Timer to be used in factories. Similar to Monitor but with additional timers.
void DeclareInput(Level &currentLevel) const
Input.
Namespace for MueLu classes and methods.
Print even more statistics.
MueLu::DefaultScalar Scalar
Class that holds all level-specific information.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
RCP< Matrix > ComputeSchurComplement(RCP< BlockedCrsMatrix > &bA, RCP< Matrix > &Ainv) const
Schur complement calculation method.
void Build(Level &currentLevel) const
Build an object with this factory.
Exception thrown to report errors in the internal logic of the program.
void DeclareInput(const std::string &ename, const FactoryBase *factory, const FactoryBase *requestedBy=NoFactory::get())
Callback from FactoryBase::CallDeclareInput() and FactoryBase::DeclareInput()
static const RCP< const NoFactory > getRCP()
Static Get() functions.