LLTMatrix, LUscalarMatrix, QRMatrix: Provided consistent 'solve' interface

This commit is contained in:
Henry Weller 2016-03-24 19:13:04 +00:00
parent f019d1e738
commit d4046bb85e
7 changed files with 93 additions and 73 deletions

View File

@ -24,6 +24,7 @@ License
\*---------------------------------------------------------------------------*/
#include "scalarMatrices.H"
#include "LUscalarMatrix.H"
#include "LLTMatrix.H"
#include "QRMatrix.H"
#include "vector.H"
@ -113,70 +114,53 @@ int main(int argc, char *argv[])
Info<< "Solution = " << rhs << endl;
}
scalarSquareMatrix squareMatrix(3, Zero);
squareMatrix(0, 0) = 4;
squareMatrix(0, 1) = 12;
squareMatrix(0, 2) = -16;
squareMatrix(1, 0) = 12;
squareMatrix(1, 1) = 37;
squareMatrix(1, 2) = -43;
squareMatrix(2, 0) = -16;
squareMatrix(2, 1) = -43;
squareMatrix(2, 2) = 98;
Info<< nl << "Square Matrix = " << squareMatrix << endl;
const scalarField source(3, 1);
{
scalarSquareMatrix squareMatrix(3, Zero);
squareMatrix(0, 0) = 4;
squareMatrix(0, 1) = 12;
squareMatrix(0, 2) = -16;
squareMatrix(1, 0) = 12;
squareMatrix(1, 1) = 37;
squareMatrix(1, 2) = -43;
squareMatrix(2, 0) = -16;
squareMatrix(2, 1) = -43;
squareMatrix(2, 2) = 98;
const scalarSquareMatrix squareMatrixCopy = squareMatrix;
Info<< nl << "Square Matrix = " << squareMatrix << endl;
Info<< "det = " << det(squareMatrixCopy) << endl;
{
scalarSquareMatrix sm(squareMatrix);
Info<< "det = " << det(sm) << endl;
}
scalarSquareMatrix sm(squareMatrix);
labelList rhs(3, 0);
label sign;
LUDecompose(squareMatrix, rhs, sign);
LUDecompose(sm, rhs, sign);
Info<< "Decomposition = " << squareMatrix << endl;
Info<< "Decomposition = " << sm << endl;
Info<< "Pivots = " << rhs << endl;
Info<< "Sign = " << sign << endl;
Info<< "det = " << detDecomposed(squareMatrix, sign) << endl;
Info<< "det = " << detDecomposed(sm, sign) << endl;
}
{
scalarSquareMatrix squareMatrix(3, Zero);
squareMatrix(0, 0) = 4;
squareMatrix(0, 1) = 12;
squareMatrix(0, 2) = -16;
squareMatrix(1, 0) = 12;
squareMatrix(1, 1) = 37;
squareMatrix(1, 2) = -43;
squareMatrix(2, 0) = -16;
squareMatrix(2, 1) = -43;
squareMatrix(2, 2) = 98;
scalarField source(3, 1);
LUscalarMatrix LU(squareMatrix);
scalarField x((LU.solve(source));
Info<< "LU solve residual " << (squareMatrix*x - source) << endl;
}
{
LLTMatrix<scalar> LLT(squareMatrix);
scalarField x(LLT.solve(source));
Info<< "LLT solve residual " << (squareMatrix*x - source) << endl;
}
{
scalarSquareMatrix squareMatrix(3, Zero);
squareMatrix(0, 0) = 4;
squareMatrix(0, 1) = 12;
squareMatrix(0, 2) = -16;
squareMatrix(1, 0) = 12;
squareMatrix(1, 1) = 37;
squareMatrix(1, 2) = -43;
squareMatrix(2, 0) = -16;
squareMatrix(2, 1) = -43;
squareMatrix(2, 2) = 98;
scalarField source(3, 1);
QRMatrix<scalarSquareMatrix> QR(squareMatrix);
scalarField x(QR.solve(source));
@ -184,8 +168,7 @@ int main(int argc, char *argv[])
<< (squareMatrix*x - source) << endl;
Info<< "QR inverse solve residual "
<< (x - QR.inverse()*source) << endl;
<< (x - QR.inv()*source) << endl;
}
Info<< "\nEnd\n" << endl;

View File

@ -95,6 +95,12 @@ void Foam::LLTMatrix<Type>::solve
const Field<Type>& source
) const
{
// If x and source are different initialize x = source
if (&x != &source)
{
x = source;
}
const SquareMatrix<Type>& LLT = *this;
const label m = LLT.m();

View File

@ -116,10 +116,16 @@ public:
//- Perform the LU decomposition of the matrix M
void decompose(const scalarSquareMatrix& M);
//- Solve the matrix using the LU decomposition with pivoting
// returning the solution in the source
template<class T>
void solve(Field<T>& source) const;
//- Solve the linear system with the given source
// and returning the solution in the Field argument x.
// This function may be called with the same field for x and source.
template<class Type>
void solve(Field<Type>& x, const Field<Type>& source) const;
//- Solve the linear system with the given source
// returning the solution
template<class Type>
tmp<Field<Type>> solve(const Field<Type>& source) const;
};

View File

@ -24,23 +24,34 @@ License
\*---------------------------------------------------------------------------*/
#include "LUscalarMatrix.H"
#include "SubField.H"
// * * * * * * * * * * * * * * * Member Functions * * * * * * * * * * * * * //
template<class Type>
void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
void Foam::LUscalarMatrix::solve
(
Field<Type>& x,
const Field<Type>& source
) const
{
// If x and source are different initialize x = source
if (&x != &source)
{
x = source;
}
if (Pstream::parRun())
{
Field<Type> completeSourceSol(m());
Field<Type> X(m());
if (Pstream::master(comm_))
{
typename Field<Type>::subField
(
completeSourceSol,
sourceSol.size()
).assign(sourceSol);
X,
x.size()
).assign(x);
for
(
@ -55,7 +66,7 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
slave,
reinterpret_cast<char*>
(
&(completeSourceSol[procOffsets_[slave]])
&(X[procOffsets_[slave]])
),
(procOffsets_[slave+1]-procOffsets_[slave])*sizeof(Type),
Pstream::msgType(),
@ -69,8 +80,8 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
(
Pstream::scheduled,
Pstream::masterNo(),
reinterpret_cast<const char*>(sourceSol.begin()),
sourceSol.byteSize(),
reinterpret_cast<const char*>(x.begin()),
x.byteSize(),
Pstream::msgType(),
comm_
);
@ -78,12 +89,12 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
if (Pstream::master(comm_))
{
LUBacksubstitute(*this, pivotIndices_, completeSourceSol);
LUBacksubstitute(*this, pivotIndices_, X);
sourceSol = typename Field<Type>::subField
x = typename Field<Type>::subField
(
completeSourceSol,
sourceSol.size()
X,
x.size()
);
for
@ -99,7 +110,7 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
slave,
reinterpret_cast<const char*>
(
&(completeSourceSol[procOffsets_[slave]])
&(X[procOffsets_[slave]])
),
(procOffsets_[slave + 1]-procOffsets_[slave])*sizeof(Type),
Pstream::msgType(),
@ -113,8 +124,8 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
(
Pstream::scheduled,
Pstream::masterNo(),
reinterpret_cast<char*>(sourceSol.begin()),
sourceSol.byteSize(),
reinterpret_cast<char*>(x.begin()),
x.byteSize(),
Pstream::msgType(),
comm_
);
@ -122,9 +133,24 @@ void Foam::LUscalarMatrix::solve(Field<Type>& sourceSol) const
}
else
{
LUBacksubstitute(*this, pivotIndices_, sourceSol);
LUBacksubstitute(*this, pivotIndices_, x);
}
}
template<class Type>
Foam::tmp<Foam::Field<Type>> Foam::LUscalarMatrix::solve
(
const Field<Type>& source
) const
{
tmp<Field<Type>> tx(new Field<Type>(m()));
Field<Type>& x = tx.ref();
solve(x, source);
return tx;
}
// ************************************************************************* //

View File

@ -225,7 +225,7 @@ Foam::QRMatrix<MatrixType>::solve
template<class MatrixType>
typename Foam::QRMatrix<MatrixType>::QMatrixType
Foam::QRMatrix<MatrixType>::inverse() const
Foam::QRMatrix<MatrixType>::inv() const
{
const label m = Q_.m();

View File

@ -108,7 +108,7 @@ public:
tmp<Field<cmptType>> solve(const Field<cmptType>& source) const;
//- Return the inverse of a square matrix
QMatrixType inverse() const;
QMatrixType inv() const;
};

View File

@ -533,8 +533,7 @@ void Foam::GAMGSolver::solveCoarsestLevel
if (directSolveCoarsest_)
{
coarsestCorrField = coarsestSource;
coarsestLUMatrixPtr_->solve(coarsestCorrField);
coarsestLUMatrixPtr_->solve(coarsestCorrField, coarsestSource);
}
//else if
//(