1d275d636SJeremy L Thompson // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37d8d0e25Snbeams // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 57d8d0e25Snbeams // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 77d8d0e25Snbeams 849aac155SJeremy L Thompson #include <ceed.h> 9ec3da8bcSJed Brown #include <ceed/backend.h> 1049aac155SJeremy L Thompson #include <ceed/jit-source/hip/hip-types.h> 113d576824SJeremy L Thompson #include <stddef.h> 123a2968d6SJeremy L Thompson #include <hip/hiprtc.h> 132b730f8bSJeremy L Thompson 14b2165e7aSSebastian Grimberg #include "../hip/ceed-hip-common.h" 157d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 162b730f8bSJeremy L Thompson #include "ceed-hip-gen-operator-build.h" 172b730f8bSJeremy L Thompson #include "ceed-hip-gen.h" 187d8d0e25Snbeams 197d8d0e25Snbeams //------------------------------------------------------------------------------ 207d8d0e25Snbeams // Destroy operator 217d8d0e25Snbeams //------------------------------------------------------------------------------ 227d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) { 233a2968d6SJeremy L Thompson Ceed ceed; 247d8d0e25Snbeams CeedOperator_Hip_gen *impl; 256eee1ffcSZach Atkins bool is_composite; 26b7453713SJeremy L Thompson 273a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 282b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &impl)); 296eee1ffcSZach Atkins CeedCallBackend(CeedOperatorIsComposite(op, &is_composite)); 306eee1ffcSZach Atkins if (is_composite) { 316eee1ffcSZach Atkins CeedInt num_suboperators; 326eee1ffcSZach Atkins 336eee1ffcSZach Atkins CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 346eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 
356eee1ffcSZach Atkins if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i])); 366eee1ffcSZach Atkins impl->streams[i] = NULL; 376eee1ffcSZach Atkins } 386eee1ffcSZach Atkins } 398b7d3340SJeremy L Thompson if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module)); 400183ed61SJeremy L Thompson if (impl->module_assemble_full) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_full)); 410183ed61SJeremy L Thompson if (impl->module_assemble_diagonal) CeedCallHip(ceed, hipModuleUnload(impl->module_assemble_diagonal)); 423a2968d6SJeremy L Thompson if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem)); 432b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 443a2968d6SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 45e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 467d8d0e25Snbeams } 477d8d0e25Snbeams 487d8d0e25Snbeams //------------------------------------------------------------------------------ 497d8d0e25Snbeams // Apply and add to output 507d8d0e25Snbeams //------------------------------------------------------------------------------ 51e9c76bddSJeremy L Thompson static int CeedOperatorApplyAddCore_Hip_gen(CeedOperator op, hipStream_t stream, const CeedScalar *input_arr, CeedScalar *output_arr, 52e9c76bddSJeremy L Thompson bool *is_run_good, CeedRequest *request) { 53ea04d07fSJeremy L Thompson bool is_at_points, is_tensor; 547d8d0e25Snbeams Ceed ceed; 55b7453713SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 56b7453713SJeremy L Thompson CeedEvalMode eval_mode; 57b7453713SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 587d8d0e25Snbeams CeedQFunction_Hip_gen *qf_data; 59b7453713SJeremy L Thompson CeedQFunction qf; 60b7453713SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 61b7453713SJeremy L Thompson CeedOperator_Hip_gen *data; 62b7453713SJeremy L Thompson 638d12f40eSJeremy L Thompson // Creation of 
the operator 64ea04d07fSJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, is_run_good)); 65ea04d07fSJeremy L Thompson if (!(*is_run_good)) return CEED_ERROR_SUCCESS; 66f6eafd79SJeremy L Thompson 67c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 68c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 69c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 70c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 71c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 728d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 73c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 74c11e12f4SJeremy L Thompson 757d8d0e25Snbeams // Input vectors 769e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 772b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 789e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 799e201c85SYohann data->fields.inputs[i] = NULL; 807d8d0e25Snbeams } else { 813efc994bSJeremy L Thompson bool is_active; 82b7453713SJeremy L Thompson CeedVector vec; 83b7453713SJeremy L Thompson 847d8d0e25Snbeams // Get input vector 852b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 863efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 87ea04d07fSJeremy L Thompson if (is_active) data->fields.inputs[i] = input_arr; 88ea04d07fSJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 89ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 907d8d0e25Snbeams } 917d8d0e25Snbeams } 927d8d0e25Snbeams 937d8d0e25Snbeams // Output vectors 949e201c85SYohann for (CeedInt i 
= 0; i < num_output_fields; i++) { 952b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 969e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 979e201c85SYohann data->fields.outputs[i] = NULL; 987d8d0e25Snbeams } else { 993efc994bSJeremy L Thompson bool is_active; 100b7453713SJeremy L Thompson CeedVector vec; 101b7453713SJeremy L Thompson 1027d8d0e25Snbeams // Get output vector 1032b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 1043efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 105ea04d07fSJeremy L Thompson if (is_active) data->fields.outputs[i] = output_arr; 1060c8fbeedSJeremy L Thompson else CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i])); 107ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1087d8d0e25Snbeams } 1097d8d0e25Snbeams } 1107d8d0e25Snbeams 1113a2968d6SJeremy L Thompson // Point coordinates, if needed 1123a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points)); 1133a2968d6SJeremy L Thompson if (is_at_points) { 1143a2968d6SJeremy L Thompson // Coords 1153a2968d6SJeremy L Thompson CeedVector vec; 1163a2968d6SJeremy L Thompson 1173a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 1183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 1193a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1203a2968d6SJeremy L Thompson 1213a2968d6SJeremy L Thompson // Points per elem 1223a2968d6SJeremy L Thompson if (num_elem != data->points.num_elem) { 1233a2968d6SJeremy L Thompson CeedInt *points_per_elem; 1243a2968d6SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 1253a2968d6SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 1263a2968d6SJeremy L Thompson 1273a2968d6SJeremy L Thompson data->points.num_elem = 
num_elem; 1283a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 1293a2968d6SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 1303a2968d6SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 1313a2968d6SJeremy L Thompson CeedInt num_points_elem; 1323a2968d6SJeremy L Thompson 1333a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 1343a2968d6SJeremy L Thompson points_per_elem[e] = num_points_elem; 1353a2968d6SJeremy L Thompson } 1363a2968d6SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 1373a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 1383a2968d6SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 1393a2968d6SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 1403a2968d6SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 1413a2968d6SJeremy L Thompson } 1423a2968d6SJeremy L Thompson } 1433a2968d6SJeremy L Thompson 1447d8d0e25Snbeams // Get context data 1452b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 1467d8d0e25Snbeams 1477d8d0e25Snbeams // Apply operator 1483a2968d6SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points}; 149b7453713SJeremy L Thompson 1509123fb08SJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor)); 151a61b1c91SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, ((!is_tensor || data->dim == 1) ? 
1 : data->thread_1d), -1}; 152f82027a4SJeremy L Thompson 153f82027a4SJeremy L Thompson if (is_tensor) { 15474398b5aSJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 15590c30374SJeremy L Thompson if (is_at_points) block_sizes[2] = 1; 156f82027a4SJeremy L Thompson } else { 157a61b1c91SJeremy L Thompson CeedInt elems_per_block = 64 * data->thread_1d > 256 ? 256 / data->thread_1d : 64; 158f82027a4SJeremy L Thompson 159f82027a4SJeremy L Thompson elems_per_block = elems_per_block > 0 ? elems_per_block : 1; 160f82027a4SJeremy L Thompson block_sizes[2] = elems_per_block; 161f82027a4SJeremy L Thompson } 16274398b5aSJeremy L Thompson if (data->dim == 1 || !is_tensor) { 1632b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 164a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 165b7453713SJeremy L Thompson 1668d12f40eSJeremy L Thompson CeedCallBackend( 167e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 16874398b5aSJeremy L Thompson } else if (data->dim == 2) { 1692b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 170a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 171b7453713SJeremy L Thompson 1728d12f40eSJeremy L Thompson CeedCallBackend( 173e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 17474398b5aSJeremy L Thompson } else if (data->dim == 3) { 1752b730f8bSJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 
1 : 0); 176a61b1c91SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 177b7453713SJeremy L Thompson 1788d12f40eSJeremy L Thompson CeedCallBackend( 179e9c76bddSJeremy L Thompson CeedTryRunKernelDimShared_Hip(ceed, data->op, stream, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, is_run_good, opargs)); 1807d8d0e25Snbeams } 1817d8d0e25Snbeams 1827d8d0e25Snbeams // Restore input arrays 1839e201c85SYohann for (CeedInt i = 0; i < num_input_fields; i++) { 1842b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 1859e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 1867d8d0e25Snbeams } else { 1873efc994bSJeremy L Thompson bool is_active; 188b7453713SJeremy L Thompson CeedVector vec; 189b7453713SJeremy L Thompson 1902b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 1913efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 192ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 193ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 1947d8d0e25Snbeams } 1957d8d0e25Snbeams } 1967d8d0e25Snbeams 1977d8d0e25Snbeams // Restore output arrays 1989e201c85SYohann for (CeedInt i = 0; i < num_output_fields; i++) { 1992b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 2009e201c85SYohann if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 2017d8d0e25Snbeams } else { 2023efc994bSJeremy L Thompson bool is_active; 203b7453713SJeremy L Thompson CeedVector vec; 204b7453713SJeremy L Thompson 2052b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 2063efc994bSJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 207ea04d07fSJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArray(vec, 
&data->fields.outputs[i])); 208ea04d07fSJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2097d8d0e25Snbeams } 2107d8d0e25Snbeams } 2117d8d0e25Snbeams 2123a2968d6SJeremy L Thompson // Restore point coordinates, if needed 2133a2968d6SJeremy L Thompson if (is_at_points) { 2143a2968d6SJeremy L Thompson CeedVector vec; 2153a2968d6SJeremy L Thompson 2163a2968d6SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 2173a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 2183a2968d6SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 2193a2968d6SJeremy L Thompson } 2203a2968d6SJeremy L Thompson 2217d8d0e25Snbeams // Restore context data 2222b730f8bSJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 2238d12f40eSJeremy L Thompson 2248d12f40eSJeremy L Thompson // Cleanup 2259bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 226c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 227ea04d07fSJeremy L Thompson if (!(*is_run_good)) data->use_fallback = true; 228ea04d07fSJeremy L Thompson return CEED_ERROR_SUCCESS; 229ea04d07fSJeremy L Thompson } 2308d12f40eSJeremy L Thompson 231ea04d07fSJeremy L Thompson static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 232ea04d07fSJeremy L Thompson bool is_run_good = false; 233ea04d07fSJeremy L Thompson const CeedScalar *input_arr = NULL; 234ea04d07fSJeremy L Thompson CeedScalar *output_arr = NULL; 235ea04d07fSJeremy L Thompson 236ea04d07fSJeremy L Thompson // Try to run kernel 237ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 238ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 239087855afSJeremy L Thompson 
CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(op, NULL, input_arr, output_arr, &is_run_good, request)); 240ea04d07fSJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 241ea04d07fSJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 242ea04d07fSJeremy L Thompson 243ea04d07fSJeremy L Thompson // Fallback on unsuccessful run 244ea04d07fSJeremy L Thompson if (!is_run_good) { 2458d12f40eSJeremy L Thompson CeedOperator op_fallback; 2468d12f40eSJeremy L Thompson 247ea04d07fSJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 2488d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 2498d12f40eSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 2508d12f40eSJeremy L Thompson } 251e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2527d8d0e25Snbeams } 2537d8d0e25Snbeams 254c99afcd8SJeremy L Thompson static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 2556eee1ffcSZach Atkins bool is_run_good[CEED_COMPOSITE_MAX] = {true}; 256c99afcd8SJeremy L Thompson CeedInt num_suboperators; 257c99afcd8SJeremy L Thompson const CeedScalar *input_arr = NULL; 2586eee1ffcSZach Atkins CeedScalar *output_arr; 259087855afSJeremy L Thompson Ceed ceed; 2606eee1ffcSZach Atkins CeedOperator_Hip_gen *impl; 261c99afcd8SJeremy L Thompson CeedOperator *sub_operators; 262c99afcd8SJeremy L Thompson 263087855afSJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 2646eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetData(op, &impl)); 2656eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators)); 2666eee1ffcSZach Atkins CeedCallBackend(CeedCompositeOperatorGetSubList(op, 
&sub_operators)); 267c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr)); 268c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr)); 269c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 270c99afcd8SJeremy L Thompson CeedInt num_elem = 0; 271c99afcd8SJeremy L Thompson 2726eee1ffcSZach Atkins CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem)); 273c99afcd8SJeremy L Thompson if (num_elem > 0) { 2746eee1ffcSZach Atkins if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i])); 2756eee1ffcSZach Atkins CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request)); 2766eee1ffcSZach Atkins } else { 2776eee1ffcSZach Atkins is_run_good[i] = true; 2786eee1ffcSZach Atkins } 2796eee1ffcSZach Atkins } 280087855afSJeremy L Thompson 2816eee1ffcSZach Atkins for (CeedInt i = 0; i < num_suboperators; i++) { 2826eee1ffcSZach Atkins if (impl->streams[i]) { 2836eee1ffcSZach Atkins if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i])); 284c99afcd8SJeremy L Thompson } 285c99afcd8SJeremy L Thompson } 286c99afcd8SJeremy L Thompson if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr)); 287c99afcd8SJeremy L Thompson if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArray(output_vec, &output_arr)); 288087855afSJeremy L Thompson CeedCallHip(ceed, hipDeviceSynchronize()); 289c99afcd8SJeremy L Thompson 290c99afcd8SJeremy L Thompson // Fallback on unsuccessful run 291c99afcd8SJeremy L Thompson for (CeedInt i = 0; i < num_suboperators; i++) { 292c99afcd8SJeremy L Thompson if (!is_run_good[i]) { 293c99afcd8SJeremy L Thompson CeedOperator op_fallback; 294c99afcd8SJeremy L Thompson 295087855afSJeremy L 
Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 296c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(sub_operators[i], &op_fallback)); 297c99afcd8SJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request)); 298c99afcd8SJeremy L Thompson } 299c99afcd8SJeremy L Thompson } 300087855afSJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 301c99afcd8SJeremy L Thompson return CEED_ERROR_SUCCESS; 302c99afcd8SJeremy L Thompson } 303c99afcd8SJeremy L Thompson 3047d8d0e25Snbeams //------------------------------------------------------------------------------ 3050183ed61SJeremy L Thompson // AtPoints diagonal assembly 3060183ed61SJeremy L Thompson //------------------------------------------------------------------------------ 3070183ed61SJeremy L Thompson static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) { 3080183ed61SJeremy L Thompson Ceed ceed; 3090183ed61SJeremy L Thompson CeedOperator_Hip_gen *data; 3100183ed61SJeremy L Thompson 3110183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 3120183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &data)); 3130183ed61SJeremy L Thompson 3140183ed61SJeremy L Thompson // Build the assembly kernel 3150183ed61SJeremy L Thompson if (!data->assemble_diagonal && !data->use_assembly_fallback) { 3160183ed61SJeremy L Thompson bool is_build_good = false; 3170183ed61SJeremy L Thompson CeedInt num_active_bases_in, num_active_bases_out; 3180183ed61SJeremy L Thompson CeedOperatorAssemblyData assembly_data; 3190183ed61SJeremy L Thompson 3200183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data)); 3210183ed61SJeremy L Thompson CeedCallBackend( 3220183ed61SJeremy L Thompson CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, 
&num_active_bases_out, NULL, NULL, NULL, NULL)); 3230183ed61SJeremy L Thompson if (num_active_bases_in == num_active_bases_out) { 3240183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good)); 3250183ed61SJeremy L Thompson if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(op, &is_build_good)); 3260183ed61SJeremy L Thompson } 3270183ed61SJeremy L Thompson if (!is_build_good) data->use_assembly_fallback = true; 3280183ed61SJeremy L Thompson } 3290183ed61SJeremy L Thompson 3300183ed61SJeremy L Thompson // Try assembly 3310183ed61SJeremy L Thompson if (!data->use_assembly_fallback) { 3320183ed61SJeremy L Thompson bool is_run_good = true; 3330183ed61SJeremy L Thompson Ceed_Hip *hip_data; 3340183ed61SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields; 3350183ed61SJeremy L Thompson CeedEvalMode eval_mode; 3360183ed61SJeremy L Thompson CeedScalar *assembled_array; 3370183ed61SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields; 3380183ed61SJeremy L Thompson CeedQFunction_Hip_gen *qf_data; 3390183ed61SJeremy L Thompson CeedQFunction qf; 3400183ed61SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields; 3410183ed61SJeremy L Thompson 3420183ed61SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &hip_data)); 3430183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 3440183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_data)); 3450183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 3460183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 3470183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 3480183ed61SJeremy L Thompson 3490183ed61SJeremy L Thompson // Input vectors 3500183ed61SJeremy L Thompson for 
(CeedInt i = 0; i < num_input_fields; i++) { 3510183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 3520183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 3530183ed61SJeremy L Thompson data->fields.inputs[i] = NULL; 3540183ed61SJeremy L Thompson } else { 3550183ed61SJeremy L Thompson bool is_active; 3560183ed61SJeremy L Thompson CeedVector vec; 3570183ed61SJeremy L Thompson 3580183ed61SJeremy L Thompson // Get input vector 3590183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 3600183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 3610183ed61SJeremy L Thompson if (is_active) data->fields.inputs[i] = NULL; 3620183ed61SJeremy L Thompson else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i])); 3630183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 3640183ed61SJeremy L Thompson } 3650183ed61SJeremy L Thompson } 3660183ed61SJeremy L Thompson 3670183ed61SJeremy L Thompson // Point coordinates 3680183ed61SJeremy L Thompson { 3690183ed61SJeremy L Thompson CeedVector vec; 3700183ed61SJeremy L Thompson 3710183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 3720183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords)); 3730183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 3740183ed61SJeremy L Thompson 3750183ed61SJeremy L Thompson // Points per elem 3760183ed61SJeremy L Thompson if (num_elem != data->points.num_elem) { 3770183ed61SJeremy L Thompson CeedInt *points_per_elem; 3780183ed61SJeremy L Thompson const CeedInt num_bytes = num_elem * sizeof(CeedInt); 3790183ed61SJeremy L Thompson CeedElemRestriction rstr_points = NULL; 3800183ed61SJeremy L Thompson 3810183ed61SJeremy L Thompson data->points.num_elem = num_elem; 3820183ed61SJeremy L Thompson 
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL)); 3830183ed61SJeremy L Thompson CeedCallBackend(CeedCalloc(num_elem, &points_per_elem)); 3840183ed61SJeremy L Thompson for (CeedInt e = 0; e < num_elem; e++) { 3850183ed61SJeremy L Thompson CeedInt num_points_elem; 3860183ed61SJeremy L Thompson 3870183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem)); 3880183ed61SJeremy L Thompson points_per_elem[e] = num_points_elem; 3890183ed61SJeremy L Thompson } 3900183ed61SJeremy L Thompson if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem)); 3910183ed61SJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes)); 3920183ed61SJeremy L Thompson CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice)); 3930183ed61SJeremy L Thompson CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points)); 3940183ed61SJeremy L Thompson CeedCallBackend(CeedFree(&points_per_elem)); 3950183ed61SJeremy L Thompson } 3960183ed61SJeremy L Thompson } 3970183ed61SJeremy L Thompson 3980183ed61SJeremy L Thompson // Get context data 3990183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c)); 4000183ed61SJeremy L Thompson 4010183ed61SJeremy L Thompson // Assembly array 4020183ed61SJeremy L Thompson CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array)); 4030183ed61SJeremy L Thompson 4040183ed61SJeremy L Thompson // Assemble diagonal 4050183ed61SJeremy L Thompson void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array}; 4060183ed61SJeremy L Thompson 4070183ed61SJeremy L Thompson CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 
1 : data->thread_1d), -1}; 4080183ed61SJeremy L Thompson 4090183ed61SJeremy L Thompson CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes)); 4100183ed61SJeremy L Thompson block_sizes[2] = 1; 4110183ed61SJeremy L Thompson if (data->dim == 1) { 4120183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 4130183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * sizeof(CeedScalar); 4140183ed61SJeremy L Thompson 4150183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4160183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4170183ed61SJeremy L Thompson } else if (data->dim == 2) { 4180183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 4190183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 4200183ed61SJeremy L Thompson 4210183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4220183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4230183ed61SJeremy L Thompson } else if (data->dim == 3) { 4240183ed61SJeremy L Thompson CeedInt grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 
1 : 0); 4250183ed61SJeremy L Thompson CeedInt sharedMem = block_sizes[2] * data->thread_1d * data->thread_1d * sizeof(CeedScalar); 4260183ed61SJeremy L Thompson 4270183ed61SJeremy L Thompson CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_diagonal, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], 4280183ed61SJeremy L Thompson sharedMem, &is_run_good, opargs)); 4290183ed61SJeremy L Thompson } 430*692716b7SZach Atkins CeedCallHip(ceed, hipDeviceSynchronize()); 4310183ed61SJeremy L Thompson 4320183ed61SJeremy L Thompson // Restore input arrays 4330183ed61SJeremy L Thompson for (CeedInt i = 0; i < num_input_fields; i++) { 4340183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 4350183ed61SJeremy L Thompson if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 4360183ed61SJeremy L Thompson } else { 4370183ed61SJeremy L Thompson bool is_active; 4380183ed61SJeremy L Thompson CeedVector vec; 4390183ed61SJeremy L Thompson 4400183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 4410183ed61SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE; 4420183ed61SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i])); 4430183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 4440183ed61SJeremy L Thompson } 4450183ed61SJeremy L Thompson } 4460183ed61SJeremy L Thompson 4470183ed61SJeremy L Thompson // Restore point coordinates 4480183ed61SJeremy L Thompson { 4490183ed61SJeremy L Thompson CeedVector vec; 4500183ed61SJeremy L Thompson 4510183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec)); 4520183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords)); 4530183ed61SJeremy L Thompson CeedCallBackend(CeedVectorDestroy(&vec)); 4540183ed61SJeremy L Thompson } 4550183ed61SJeremy L Thompson 4560183ed61SJeremy L Thompson // Restore 
context data 4570183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c)); 4580183ed61SJeremy L Thompson 4590183ed61SJeremy L Thompson // Restore assembly array 4600183ed61SJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array)); 4610183ed61SJeremy L Thompson 4620183ed61SJeremy L Thompson // Cleanup 4630183ed61SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf)); 4640183ed61SJeremy L Thompson if (!is_run_good) data->use_assembly_fallback = true; 4650183ed61SJeremy L Thompson } 4660183ed61SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 4670183ed61SJeremy L Thompson 4680183ed61SJeremy L Thompson // Fallback, if needed 4690183ed61SJeremy L Thompson if (data->use_assembly_fallback) { 4700183ed61SJeremy L Thompson CeedOperator op_fallback; 4710183ed61SJeremy L Thompson 4720183ed61SJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator"); 4730183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback)); 4740183ed61SJeremy L Thompson CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request)); 4750183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 4760183ed61SJeremy L Thompson } 4770183ed61SJeremy L Thompson return CEED_ERROR_SUCCESS; 4780183ed61SJeremy L Thompson } 4790183ed61SJeremy L Thompson
//------------------------------------------------------------------------------
// AtPoints full assembly
//
// Assembles a single AtPoints operator into `assembled` with the generated
// `assemble_full` kernel. The kernel is handed a pointer to entry `offset` of
// the device array of `assembled`.
//
// The kernel is built the first time it is needed. If the build or the launch
// does not succeed, `use_assembly_fallback` is set on the operator data and the
// request is forwarded to the operator returned by CeedOperatorGetFallback().
//
// Returns CEED_ERROR_SUCCESS, or the error raised by a failing libCEED/HIP call.
//------------------------------------------------------------------------------
static int CeedSingleOperatorAssembleAtPoints_Hip_gen(CeedOperator op, CeedInt offset, CeedVector assembled) {
  Ceed                  ceed;
  CeedOperator_Hip_gen *data;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &data));

  // Build the assembly kernel, unless it exists already or a previous attempt failed
  if (!data->assemble_full && !data->use_assembly_fallback) {
    bool                     is_build_good = false;
    CeedInt                  num_active_bases_in, num_active_bases_out;
    CeedOperatorAssemblyData assembly_data;

    CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
    CeedCallBackend(
        CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL));
    // The build is only attempted when the active input and output basis counts agree
    if (num_active_bases_in == num_active_bases_out) {
      CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op, &is_build_good));
      if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(op, &is_build_good));
    }
    if (!is_build_good) {
      CeedDebug(ceed, "Single Operator Assemble at Points compile failed, using fallback\n");
      data->use_assembly_fallback = true;
    }
  }

  // Try assembly with the generated kernel
  if (!data->use_assembly_fallback) {
    bool                   is_run_good = true;
    CeedInt                num_elem, num_input_fields, num_output_fields;
    CeedEvalMode           eval_mode;
    CeedScalar            *assembled_array, *assembled_offset_array;
    CeedQFunctionField    *qf_input_fields, *qf_output_fields;
    CeedQFunction_Hip_gen *qf_data;
    CeedQFunction          qf;
    CeedOperatorField     *op_input_fields, *op_output_fields;

    CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
    CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
    CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
    CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
    CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
    CeedDebug(ceed, "Running single operator assemble for /gpu/hip/gen\n");

    // Input vectors: weight and active inputs get a NULL array, all others are read on the device
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedVector vec;

      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) {
        data->fields.inputs[i] = NULL;
        continue;
      }
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
      if (vec == CEED_VECTOR_ACTIVE) data->fields.inputs[i] = NULL;
      else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Points per elem, re-uploaded only when the element count differs from the cached one
    if (num_elem != data->points.num_elem) {
      CeedInt            *points_per_elem;
      const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
      CeedElemRestriction rstr_points = NULL;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
      CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
      for (CeedInt e = 0; e < num_elem; e++) {
        CeedInt num_points_elem;

        CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
        points_per_elem[e] = num_points_elem;
      }
      CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
      if (data->points.num_per_elem) {
        CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
        // Clear the freed pointer so a failure below cannot lead to a second hipFree on destroy
        data->points.num_per_elem = NULL;
      }
      CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
      CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
      CeedCallBackend(CeedFree(&points_per_elem));
      // Record the cached element count only once the device copy is in place
      data->points.num_elem = num_elem;
    }

    // Get context data
    CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));

    // Assembly array; the kernel writes relative to entry `offset`
    CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
    assembled_offset_array = &assembled_array[offset];

    // Assemble full operator
    {
      void   *opargs[]       = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields,          &data->B,
                                &data->G,          &data->W,      &data->points,  &assembled_offset_array};
      CeedInt block_sizes[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};

      CeedCallBackend(BlockGridCalculate_Hip_gen(data->dim, num_elem, data->max_P_1d, data->Q_1d, block_sizes));
      // Override the z block size; with a value of 1 the grid below has one block per element
      block_sizes[2] = 1;
      // Only dimensions 1 through 3 launch a kernel
      if (data->dim >= 1 && data->dim <= 3) {
        const CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
        const CeedInt sharedMem = block_sizes[2] * data->thread_1d * (data->dim == 1 ? 1 : data->thread_1d) * sizeof(CeedScalar);

        CeedCallBackend(CeedTryRunKernelDimShared_Hip(ceed, data->assemble_full, NULL, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem,
                                                      &is_run_good, opargs));
      }
    }
    CeedCallHip(ceed, hipDeviceSynchronize());

    // Restore input arrays
    for (CeedInt i = 0; i < num_input_fields; i++) {
      CeedVector vec;

      CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
      if (eval_mode == CEED_EVAL_WEIGHT) continue;
      CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
      if (vec != CEED_VECTOR_ACTIVE) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Restore point coordinates
    {
      CeedVector vec;

      CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
      CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
      CeedCallBackend(CeedVectorDestroy(&vec));
    }

    // Restore context data
    CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));

    // Restore assembly array
    CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));

    // Cleanup
    CeedCallBackend(CeedQFunctionDestroy(&qf));
    if (!is_run_good) {
      CeedDebug(ceed, "Single Operator Assemble at Points run failed, using fallback\n");
      data->use_assembly_fallback = true;
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));

  // Fallback, if needed
  if (data->use_assembly_fallback) {
    CeedOperator op_fallback;

    CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator");
    CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
    CeedCallBackend(CeedSingleOperatorAssemble(op_fallback, offset, assembled));
  }
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create operator
6677d8d0e25Snbeams //------------------------------------------------------------------------------ 6687d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) { 6690183ed61SJeremy L Thompson bool is_composite, is_at_points; 6707d8d0e25Snbeams Ceed ceed; 6717d8d0e25Snbeams CeedOperator_Hip_gen *impl; 6727d8d0e25Snbeams 673b7453713SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 6742b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 6752b730f8bSJeremy L Thompson CeedCallBackend(CeedOperatorSetData(op, impl)); 676c99afcd8SJeremy L Thompson CeedCall(CeedOperatorIsComposite(op, &is_composite)); 677c99afcd8SJeremy L Thompson if (is_composite) { 678c99afcd8SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAddComposite", CeedOperatorApplyAddComposite_Hip_gen)); 679c99afcd8SJeremy L Thompson } else { 6802b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen)); 681c99afcd8SJeremy L Thompson } 6820183ed61SJeremy L Thompson CeedCall(CeedOperatorIsAtPoints(op, &is_at_points)); 6830183ed61SJeremy L Thompson if (is_at_points) { 6840183ed61SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip_gen)); 685*692716b7SZach Atkins CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssembleAtPoints_Hip_gen)); 6860183ed61SJeremy L Thompson } 6872b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen)); 6889bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed)); 689e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 6907d8d0e25Snbeams } 6912a86cc9dSSebastian Grimberg 6927d8d0e25Snbeams //------------------------------------------------------------------------------ 693