18949adfdSHong Zhang 28949adfdSHong Zhang /* 38949adfdSHong Zhang Defines matrix-matrix product routines for pairs of MPIAIJ matrices 48949adfdSHong Zhang C = A^T * B 58949adfdSHong Zhang The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense(). 68949adfdSHong Zhang */ 78949adfdSHong Zhang #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 88949adfdSHong Zhang #include <../src/mat/impls/aij/mpi/mpiaij.h> 98949adfdSHong Zhang #include <../src/mat/impls/dense/mpi/mpidense.h> 108949adfdSHong Zhang 11d71ae5a4SJacob Faibussowitsch PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data) 12d71ae5a4SJacob Faibussowitsch { 136718818eSStefano Zampini Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data; 148949adfdSHong Zhang 158949adfdSHong Zhang PetscFunctionBegin; 169566063dSJacob Faibussowitsch PetscCall(MatDestroy(&atb->mA)); 179566063dSJacob Faibussowitsch PetscCall(VecDestroy(&atb->bt)); 189566063dSJacob Faibussowitsch PetscCall(VecDestroy(&atb->ct)); 199566063dSJacob Faibussowitsch PetscCall(PetscFree(atb)); 20*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 218949adfdSHong Zhang } 228949adfdSHong Zhang 236718818eSStefano Zampini static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat); 246718818eSStefano Zampini 25d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C) 26d71ae5a4SJacob Faibussowitsch { 278949adfdSHong Zhang Mat_MatTransMatMult *atb; 286718818eSStefano Zampini PetscBool cisdense; 298949adfdSHong Zhang 308949adfdSHong Zhang PetscFunctionBegin; 316718818eSStefano Zampini MatCheckProduct(C, 4); 3228b400f6SJacob Faibussowitsch PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty"); 338949adfdSHong Zhang 348949adfdSHong Zhang /* create output dense matrix C = A^T*B */ 359566063dSJacob Faibussowitsch PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N)); 369566063dSJacob Faibussowitsch PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "")); 3748a46eb9SPierre Jolivet if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name)); 389566063dSJacob Faibussowitsch PetscCall(MatSetUp(C)); 398949adfdSHong Zhang 406718818eSStefano Zampini /* create additional data structure for the product */ 419566063dSJacob Faibussowitsch PetscCall(PetscNew(&atb)); 426718818eSStefano Zampini if (B->cmap->N) { 439566063dSJacob Faibussowitsch PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA)); 44445ca090SPierre Jolivet if (!atb->mA->assembled) { 459566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY)); 469566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY)); 47445ca090SPierre Jolivet } 489566063dSJacob Faibussowitsch PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt)); 496718818eSStefano Zampini } 506718818eSStefano Zampini C->product->data = atb; 516718818eSStefano Zampini C->product->destroy = MatDestroy_MPIDense_MatTransMatMult; 528949adfdSHong Zhang 534222ddf1SHong Zhang C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense; 54*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 558949adfdSHong Zhang } 568949adfdSHong Zhang 57d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C) 58d71ae5a4SJacob Faibussowitsch { 591683a169SBarry Smith const PetscScalar *Barray, *ctarray; 601683a169SBarry Smith PetscScalar *Carray, *btarray; 61b45e3bf4SStefano Zampini PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc; 626718818eSStefano Zampini Mat_MatTransMatMult *atb; 636718818eSStefano Zampini Vec bt, ct; 648949adfdSHong Zhang 658949adfdSHong Zhang PetscFunctionBegin; 666718818eSStefano Zampini MatCheckProduct(C, 3); 676718818eSStefano Zampini atb = (Mat_MatTransMatMult *)C->product->data; 6808401ef6SPierre Jolivet PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct"); 696718818eSStefano Zampini if (!BN) { 709566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 719566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 72*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 736718818eSStefano Zampini } 746718818eSStefano Zampini bt = atb->bt; 756718818eSStefano Zampini ct = atb->ct; 768949adfdSHong Zhang 77b45e3bf4SStefano Zampini /* transpose local array of B, then copy it to vector bt */ 789566063dSJacob Faibussowitsch PetscCall(MatDenseGetArrayRead(B, &Barray)); 799566063dSJacob Faibussowitsch PetscCall(MatDenseGetLDA(B, &ldb)); 809566063dSJacob Faibussowitsch PetscCall(VecGetArray(bt, &btarray)); 81b45e3bf4SStefano Zampini for (j = 0; j < BN; j++) 829371c9d4SSatish Balay for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i]; 839566063dSJacob Faibussowitsch PetscCall(VecRestoreArray(bt, &btarray)); 849566063dSJacob Faibussowitsch PetscCall(MatDenseRestoreArrayRead(B, &Barray)); 858949adfdSHong Zhang 868949adfdSHong Zhang /* compute ct = mA^T * cb */ 879566063dSJacob Faibussowitsch PetscCall(MatMultTranspose(atb->mA, bt, ct)); 888949adfdSHong Zhang 89905b3b74SHong Zhang /* transpose local array of ct to matrix C */ 909566063dSJacob Faibussowitsch PetscCall(MatDenseGetArray(C, &Carray)); 919566063dSJacob Faibussowitsch PetscCall(MatDenseGetLDA(C, &ldc)); 929566063dSJacob Faibussowitsch PetscCall(VecGetArrayRead(ct, &ctarray)); 93b45e3bf4SStefano Zampini for (j = 0; j < BN; j++) 949371c9d4SSatish Balay for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j]; 959566063dSJacob Faibussowitsch PetscCall(VecRestoreArrayRead(ct, &ctarray)); 969566063dSJacob Faibussowitsch PetscCall(MatDenseRestoreArray(C, &Carray)); 979566063dSJacob Faibussowitsch PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY)); 989566063dSJacob Faibussowitsch PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY)); 99*3ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 1008949adfdSHong Zhang } 101