13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37d8d0e25Snbeams // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 57d8d0e25Snbeams // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 77d8d0e25Snbeams 8ec3da8bcSJed Brown #include <ceed/ceed.h> 9ec3da8bcSJed Brown #include <ceed/backend.h> 103d576824SJeremy L Thompson #include <stddef.h> 117d8d0e25Snbeams #include "ceed-hip-gen.h" 127d8d0e25Snbeams #include "ceed-hip-gen-operator-build.h" 137d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 147d8d0e25Snbeams 157d8d0e25Snbeams //------------------------------------------------------------------------------ 167d8d0e25Snbeams // Destroy operator 177d8d0e25Snbeams //------------------------------------------------------------------------------ 187d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 197d8d0e25Snbeams int ierr; 207d8d0e25Snbeams CeedOperator_Hip_gen *impl; 21e15f9bd0SJeremy L Thompson ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr); 22e15f9bd0SJeremy L Thompson ierr = CeedFree(&impl); CeedChkBackend(ierr); 23e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 247d8d0e25Snbeams } 257d8d0e25Snbeams 267d8d0e25Snbeams //------------------------------------------------------------------------------ 277d8d0e25Snbeams // Apply and add to output 287d8d0e25Snbeams //------------------------------------------------------------------------------ 29*9e201c85SYohann static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, 30*9e201c85SYohann CeedVector output_vec, CeedRequest *request) { 317d8d0e25Snbeams int ierr; 327d8d0e25Snbeams Ceed ceed; 33e15f9bd0SJeremy L Thompson ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 347d8d0e25Snbeams CeedOperator_Hip_gen *data; 35e15f9bd0SJeremy L Thompson ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr); 367d8d0e25Snbeams CeedQFunction qf; 377d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 38e15f9bd0SJeremy L Thompson ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr); 39e15f9bd0SJeremy L Thompson ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr); 40*9e201c85SYohann CeedInt num_elem, num_input_fields, num_output_fields; 41*9e201c85SYohann ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr); 42*9e201c85SYohann CeedOperatorField *op_input_fields, *op_output_fields; 43*9e201c85SYohann ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, 44*9e201c85SYohann &num_output_fields, &op_output_fields); 45e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 46*9e201c85SYohann CeedQFunctionField *qf_input_fields, *qf_output_fields; 47*9e201c85SYohann ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, 48*9e201c85SYohann &qf_output_fields); 49e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 50*9e201c85SYohann CeedEvalMode eval_mode; 51*9e201c85SYohann CeedVector vec, output_vecs[CEED_FIELD_MAX] = {}; 527d8d0e25Snbeams 537d8d0e25Snbeams //Creation of the operator 54e15f9bd0SJeremy L Thompson ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr); 557d8d0e25Snbeams 567d8d0e25Snbeams // Input vectors 57*9e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 58*9e201c85SYohann ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 59e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 60*9e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 61*9e201c85SYohann data->fields.inputs[i] = NULL; 627d8d0e25Snbeams } else { 637d8d0e25Snbeams // Get input vector 64*9e201c85SYohann ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 65*9e201c85SYohann CeedChkBackend(ierr); 66*9e201c85SYohann if (vec == CEED_VECTOR_ACTIVE) vec = input_vec; 67*9e201c85SYohann ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]); 68e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 697d8d0e25Snbeams } 707d8d0e25Snbeams } 717d8d0e25Snbeams 727d8d0e25Snbeams // Output vectors 73*9e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 74*9e201c85SYohann ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 75e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 76*9e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 77*9e201c85SYohann data->fields.outputs[i] = NULL; 787d8d0e25Snbeams } else { 797d8d0e25Snbeams // Get output vector 80*9e201c85SYohann ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 81e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 82*9e201c85SYohann if (vec == CEED_VECTOR_ACTIVE) vec = output_vec; 83*9e201c85SYohann output_vecs[i] = vec; 847d8d0e25Snbeams // Check for multiple output modes 857d8d0e25Snbeams CeedInt index = -1; 867d8d0e25Snbeams for (CeedInt j = 0; j < i; j++) { 87*9e201c85SYohann if (vec == output_vecs[j]) { 887d8d0e25Snbeams index = j; 897d8d0e25Snbeams break; 907d8d0e25Snbeams } 917d8d0e25Snbeams } 927d8d0e25Snbeams if (index == -1) { 93*9e201c85SYohann ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]); 94e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 957d8d0e25Snbeams } else { 96*9e201c85SYohann data->fields.outputs[i] = data->fields.outputs[index]; 977d8d0e25Snbeams } 987d8d0e25Snbeams } 997d8d0e25Snbeams } 1007d8d0e25Snbeams 1017d8d0e25Snbeams // Get context data 102441428dfSJeremy L Thompson ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c); 103e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 1047d8d0e25Snbeams 1057d8d0e25Snbeams // Apply operator 106*9e201c85SYohann void *opargs[] = {(void *) &num_elem, &qf_data->d_c, &data->indices, 1077d8d0e25Snbeams &data->fields, &data->B, &data->G, &data->W 1087d8d0e25Snbeams }; 1097d8d0e25Snbeams const CeedInt dim = data->dim; 110*9e201c85SYohann const CeedInt Q_1d = data->Q_1d; 111*9e201c85SYohann const CeedInt P_1d = data->max_P_1d; 112*9e201c85SYohann const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d); 113b3e1519bSnbeams CeedInt block_sizes[3]; 114*9e201c85SYohann ierr = BlockGridCalculate_Hip_gen(dim, num_elem, P_1d, Q_1d, block_sizes); 115b3e1519bSnbeams CeedChkBackend(ierr); 1167d8d0e25Snbeams if (dim==1) { 117*9e201c85SYohann CeedInt grid = num_elem/block_sizes[2] + ( ( 118*9e201c85SYohann num_elem/block_sizes[2]*block_sizes[2]<num_elem) 1197d8d0e25Snbeams ? 1 : 0 ); 120*9e201c85SYohann CeedInt sharedMem = block_sizes[2]*thread_1d*sizeof(CeedScalar); 121b3e1519bSnbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 122b3e1519bSnbeams block_sizes[1], 123b3e1519bSnbeams block_sizes[2], sharedMem, opargs); 1247d8d0e25Snbeams } else if (dim==2) { 125*9e201c85SYohann CeedInt grid = num_elem/block_sizes[2] + ( ( 126*9e201c85SYohann num_elem/block_sizes[2]*block_sizes[2]<num_elem) 1277d8d0e25Snbeams ? 1 : 0 ); 128*9e201c85SYohann CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar); 129b3e1519bSnbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 130b3e1519bSnbeams block_sizes[1], 131b3e1519bSnbeams block_sizes[2], sharedMem, opargs); 1327d8d0e25Snbeams } else if (dim==3) { 133*9e201c85SYohann CeedInt grid = num_elem/block_sizes[2] + ( ( 134*9e201c85SYohann num_elem/block_sizes[2]*block_sizes[2]<num_elem) 1357d8d0e25Snbeams ? 1 : 0 ); 136*9e201c85SYohann CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar); 137b3e1519bSnbeams ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0], 138b3e1519bSnbeams block_sizes[1], 139b3e1519bSnbeams block_sizes[2], sharedMem, opargs); 1407d8d0e25Snbeams } 141e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 1427d8d0e25Snbeams 1437d8d0e25Snbeams // Restore input arrays 144*9e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 145*9e201c85SYohann ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode); 146e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 147*9e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1487d8d0e25Snbeams } else { 149*9e201c85SYohann ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec); 150*9e201c85SYohann CeedChkBackend(ierr); 151*9e201c85SYohann if (vec == CEED_VECTOR_ACTIVE) vec = input_vec; 152*9e201c85SYohann ierr = CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]); 153e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 1547d8d0e25Snbeams } 1557d8d0e25Snbeams } 1567d8d0e25Snbeams 1577d8d0e25Snbeams // Restore output arrays 158*9e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 159*9e201c85SYohann ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode); 160e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 161*9e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1627d8d0e25Snbeams } else { 163*9e201c85SYohann ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec); 164e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 165*9e201c85SYohann if (vec == CEED_VECTOR_ACTIVE) vec = output_vec; 1667d8d0e25Snbeams // Check for multiple output modes 1677d8d0e25Snbeams CeedInt index = -1; 1687d8d0e25Snbeams for (CeedInt j = 0; j < i; j++) { 169*9e201c85SYohann if (vec == output_vecs[j]) { 1707d8d0e25Snbeams index = j; 1717d8d0e25Snbeams break; 1727d8d0e25Snbeams } 1737d8d0e25Snbeams } 1747d8d0e25Snbeams if (index == -1) { 175*9e201c85SYohann ierr = CeedVectorRestoreArray(vec, &data->fields.outputs[i]); 176e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 1777d8d0e25Snbeams } 1787d8d0e25Snbeams } 1797d8d0e25Snbeams } 1807d8d0e25Snbeams 1817d8d0e25Snbeams // Restore context data 182441428dfSJeremy L Thompson ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c); 183e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 184441428dfSJeremy L Thompson 185e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1867d8d0e25Snbeams } 1877d8d0e25Snbeams 1887d8d0e25Snbeams //------------------------------------------------------------------------------ 1897d8d0e25Snbeams // Create operator 1907d8d0e25Snbeams //------------------------------------------------------------------------------ 1917d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 1927d8d0e25Snbeams int ierr; 1937d8d0e25Snbeams Ceed ceed; 194e15f9bd0SJeremy L Thompson ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr); 1957d8d0e25Snbeams CeedOperator_Hip_gen *impl; 1967d8d0e25Snbeams 197e15f9bd0SJeremy L Thompson ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 198e15f9bd0SJeremy L Thompson ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr); 1997d8d0e25Snbeams 2007d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", 201e15f9bd0SJeremy L Thompson CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr); 2027d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy", 203e15f9bd0SJeremy L Thompson CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr); 204e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2057d8d0e25Snbeams } 2067d8d0e25Snbeams //------------------------------------------------------------------------------ 207