//===----------------- Server.cpp - Server Implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Offloading gRPC server for remote host.
//
//===----------------------------------------------------------------------===//
|
|
#include <cmath>
#include <cstdint>
#include <cstring>
#include <future>
#include <memory>
#include <mutex>
#include <vector>

#include "Server.h"
#include "omptarget.h"
#include "openmp.grpc.pb.h"
#include "openmp.pb.h"
|
|
|
|
using grpc::WriteOptions;
|
|
|
|
extern std::promise<void> ShutdownPromise;
|
|
|
|
/// Shutdown RPC handler.
///
/// Acknowledges the request with a reply of 0 and then fulfils the global
/// ShutdownPromise.
///
/// \param Context unused.
/// \param Request unused (empty message).
/// \param Reply   receives 0.
/// \returns Status::OK.
Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request,
                                   I32 *Reply) {
  SERVER_DBG("Shutting down the server")

  // Fill in the acknowledgement before signalling through the promise.
  const int32_t Ack = 0;
  Reply->set_number(Ack);

  ShutdownPromise.set_value();

  return Status::OK;
}
|
|
|
|
/// RegisterLib RPC handler.
///
/// Rebuilds a __tgt_bin_desc from the protobuf \p Description, registers it
/// with the plugin manager and records it in Descriptions, keyed by the
/// client-side binary pointer.
///
/// \param Context     unused.
/// \param Description serialized binary description sent by the client.
/// \param Reply       receives 0.
/// \returns Status::OK.
Status
RemoteOffloadImpl::RegisterLib(ServerContext *Context,
                               const TargetBinaryDescription *Description,
                               I32 *Reply) {
  auto Desc = std::make_unique<__tgt_bin_desc>();

  unloadTargetBinaryDescription(Description, Desc.get(),
                                HostToRemoteDeviceImage);
  PM->RTLs.RegisterLib(Desc.get());

  // The descriptor that was just registered must stay alive in Descriptions.
  // Previously, when an entry for this key already existed, the old contents
  // were freed but the new descriptor was dropped at function exit, leaving
  // the map with a stale entry and the runtime with a dangling pointer.
  auto &Slot = Descriptions[(void *)Description->bin_ptr()];
  if (Slot) {
    // NOTE(review): assumes freeTargetBinaryDescription releases only the
    // buffers owned by the descriptor, not the descriptor object itself
    // (which the unique_ptr deletes on reassignment) - confirm in Utils.
    freeTargetBinaryDescription(Slot.get());
  }
  Slot = std::move(Desc);

  SERVER_DBG("Registered library")
  Reply->set_number(0);
  return Status::OK;
}
|
|
|
|
/// UnregisterLib RPC handler.
///
/// Looks up the description recorded under the client-side pointer carried in
/// \p Request. If none is recorded the reply is 1; otherwise the description
/// is unregistered from the plugin manager, freed, removed from Descriptions,
/// and the reply is 0.
///
/// \param Context unused.
/// \param Request client-side binary pointer used as the lookup key.
/// \param Reply   receives 1 when the key is unknown, 0 otherwise.
/// \returns Status::OK in both cases.
Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context,
                                        const Pointer *Request, I32 *Reply) {
  void *Key = (void *)Request->number();

  auto Entry = Descriptions.find(Key);
  if (Entry == Descriptions.end()) {
    Reply->set_number(1);
    return Status::OK;
  }

  // Tear down in the same order as before: unregister, free, then erase.
  auto *Desc = Entry->second.get();
  PM->RTLs.UnregisterLib(Desc);
  freeTargetBinaryDescription(Desc);
  Descriptions.erase(Entry);

  SERVER_DBG("Unregistered library")
  Reply->set_number(0);
  return Status::OK;
}
|
|
|
|
/// IsValidBinary RPC handler.
///
/// Translates the client-side image pointer through HostToRemoteDeviceImage
/// and asks each RTL in PM->RTLs.AllRTLs whether it accepts the image. The
/// first non-zero answer is forwarded to the client; otherwise the reply
/// stays 0.
///
/// \param Context     unused.
/// \param DeviceImage client-side image pointer.
/// \param IsValid     receives the first non-zero RTL answer, or 0.
/// \returns Status::OK.
Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context,
                                        const TargetDeviceImagePtr *DeviceImage,
                                        I32 *IsValid) {
  void *HostImagePtr = (void *)DeviceImage->image_ptr();
  __tgt_device_image *Image = HostToRemoteDeviceImage[HostImagePtr];

  // Default answer when no RTL claims the image.
  IsValid->set_number(0);

  for (auto &RTL : PM->RTLs.AllRTLs) {
    auto Ret = RTL.is_valid_binary(Image);
    if (!Ret)
      continue;

    IsValid->set_number(Ret);
    break;
  }

  SERVER_DBG("Checked if binary (%p) is valid",
             (void *)(DeviceImage->image_ptr()))
  return Status::OK;
}
|
|
|
|
/// GetNumberOfDevices RPC handler.
///
/// Runs RTLsTy::LoadRTLs at most once (guarded by PM->RTLs.initFlag), then
/// sums NumberOfDevices over PM->RTLs.AllRTLs while holding PM->RTLsMtx.
///
/// \param Context         unused.
/// \param Null            unused (empty message).
/// \param NumberOfDevices receives the summed device count.
/// \returns Status::OK.
Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context,
                                             const Null *Null,
                                             I32 *NumberOfDevices) {
  std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs);

  int32_t Devices = 0;
  {
    // Scoped lock instead of manual lock()/unlock(): the mutex is released on
    // every exit path from this scope.
    std::lock_guard<decltype(PM->RTLsMtx)> Lock(PM->RTLsMtx);
    for (auto &RTL : PM->RTLs.AllRTLs)
      Devices += RTL.NumberOfDevices;
  }

  NumberOfDevices->set_number(Devices);

  SERVER_DBG("Got number of devices")
  return Status::OK;
}
|
|
|
|
/// InitDevice RPC handler.
///
/// Calls init_device on the RTL that owns the requested device, passing the
/// id produced by mapHostRTLDeviceId, and forwards the return code.
///
/// \param Context   unused.
/// \param DeviceNum device id used to index PM->Devices.
/// \param Reply     receives the RTL's init_device return value.
/// \returns Status::OK.
Status RemoteOffloadImpl::InitDevice(ServerContext *Context,
                                     const I32 *DeviceNum, I32 *Reply) {
  const auto DeviceId = DeviceNum->number();

  auto &RTL = PM->Devices[DeviceId]->RTL;
  const auto Ret = RTL->init_device(mapHostRTLDeviceId(DeviceId));
  Reply->set_number(Ret);

  SERVER_DBG("Initialized device %d", DeviceNum->number())
  return Status::OK;
}
|
|
|
|
/// InitRequires RPC handler.
///
/// Passes the requires flags to every device in PM->Devices whose RTL
/// provides an init_requires entry point, then echoes the flags back.
///
/// \param Context      unused.
/// \param RequiresFlag flags sent by the client.
/// \param Reply        receives the same flag value.
/// \returns Status::OK.
Status RemoteOffloadImpl::InitRequires(ServerContext *Context,
                                       const I64 *RequiresFlag, I32 *Reply) {
  for (auto &Device : PM->Devices) {
    // init_requires is optional; skip RTLs that do not provide it.
    if (!Device->RTL->init_requires)
      continue;
    Device->RTL->init_requires(RequiresFlag->number());
  }

  Reply->set_number(RequiresFlag->number());

  SERVER_DBG("Initialized requires for devices")
  return Status::OK;
}
|
|
|
|
/// LoadBinary RPC handler.
///
/// Translates the client-side image pointer through HostToRemoteDeviceImage,
/// asks the owning RTL to load that image, stores the result in Table and,
/// when it is non-null, serializes it into \p Reply via loadTargetTable.
///
/// \param Context unused.
/// \param Binary  carries the client-side image pointer and the device id.
/// \param Reply   filled from the loaded table when loading produced one.
/// \returns Status::OK.
Status RemoteOffloadImpl::LoadBinary(ServerContext *Context,
                                     const Binary *Binary, TargetTable *Reply) {
  void *HostImagePtr = (void *)Binary->image_ptr();
  __tgt_device_image *Image = HostToRemoteDeviceImage[HostImagePtr];

  auto &RTL = PM->Devices[Binary->device_id()]->RTL;
  Table = RTL->load_binary(mapHostRTLDeviceId(Binary->device_id()), Image);

  // A null table leaves the reply untouched.
  if (Table) {
    loadTargetTable(Table, *Reply, Image);
  }

  SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(),
             Binary->device_id())
  return Status::OK;
}
|
|
|
|
/// IsDataExchangeable RPC handler.
///
/// Asks the RTL owning the source device whether data can be exchanged with
/// the destination device. The reply is -1 when the RTL does not provide an
/// is_data_exchangable entry point.
///
/// \param Context unused.
/// \param Request source and destination device ids.
/// \param Reply   receives the RTL's answer, or -1 if unsupported.
/// \returns Status::OK.
Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context,
                                             const DevicePair *Request,
                                             I32 *Reply) {
  Reply->set_number(-1);

  // Match the convention used by the other handlers in this file (see
  // DataExchange): PM->Devices is indexed by the id sent by the client and
  // the RTL entry point receives ids translated by mapHostRTLDeviceId. The
  // previous code had the two swapped.
  auto &RTL = PM->Devices[Request->src_dev_id()]->RTL;
  if (RTL->is_data_exchangable)
    Reply->set_number(
        RTL->is_data_exchangable(mapHostRTLDeviceId(Request->src_dev_id()),
                                 mapHostRTLDeviceId(Request->dst_dev_id())));

  SERVER_DBG("Checked if data exchangeable between device %d and device %d",
             Request->src_dev_id(), Request->dst_dev_id())
  return Status::OK;
}
|
|
|
|
/// DataAlloc RPC handler.
///
/// Requests an allocation of Request->size() bytes from the owning RTL using
/// TARGET_ALLOC_DEFAULT and returns the resulting pointer as an integer.
///
/// \param Context unused.
/// \param Request device id, size and the associated host pointer.
/// \param Reply   receives the pointer returned by data_alloc.
/// \returns Status::OK.
Status RemoteOffloadImpl::DataAlloc(ServerContext *Context,
                                    const AllocData *Request, Pointer *Reply) {
  auto &RTL = PM->Devices[Request->device_id()]->RTL;

  void *Allocation = RTL->data_alloc(mapHostRTLDeviceId(Request->device_id()),
                                     Request->size(),
                                     (void *)Request->hst_ptr(),
                                     TARGET_ALLOC_DEFAULT);

  // Pointers travel over the wire as 64-bit integers.
  uint64_t TgtPtr = (uint64_t)Allocation;
  Reply->set_number(TgtPtr);

  SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr))

  return Status::OK;
}
|
|
|
|
/// DataSubmit RPC handler (client-streaming).
///
/// A message whose start is 0 and whose payload length equals its declared
/// size is submitted to the device directly. Otherwise the messages are
/// treated as chunks: each payload is copied into a staging buffer at its
/// start offset, and the assembled buffer is submitted once the stream ends.
///
/// The staging buffer is sized from the first chunk's declared size. A chunk
/// that does not fit inside it is rejected, and so is a stream that ends
/// without delivering any message; in both cases the reply is -1.
///
/// \param Context unused.
/// \param Reader  stream of SubmitData messages from the client.
/// \param Reply   receives the RTL's data_submit return value, or -1 when the
///                stream is rejected.
/// \returns Status::OK.
Status RemoteOffloadImpl::DataSubmit(ServerContext *Context,
                                     ServerReader<SubmitData> *Reader,
                                     I32 *Reply) {
  SubmitData Request;
  // Owning buffer: released on every return path, including early returns.
  std::unique_ptr<uint8_t[]> HostCopy;
  uint64_t TotalSize = 0;

  while (Reader->Read(&Request)) {
    // Whole transfer contained in a single message: submit without staging.
    if (Request.start() == 0 && Request.size() == Request.data().size()) {
      Reader->SendInitialMetadata();

      Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit(
          mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
          (void *)Request.data().data(), Request.data().size()));

      SERVER_DBG("Submitted %lu bytes async to (%p) on device %d",
                 Request.data().size(), (void *)Request.tgt_ptr(),
                 Request.device_id())

      return Status::OK;
    }

    if (!HostCopy) {
      TotalSize = Request.size();
      HostCopy = std::make_unique<uint8_t[]>(TotalSize);
      Reader->SendInitialMetadata();
    }

    // start/size/data come from the network; never write outside the buffer.
    const uint64_t ChunkStart = Request.start();
    const uint64_t ChunkSize = Request.data().size();
    if (ChunkStart > TotalSize || ChunkSize > TotalSize - ChunkStart) {
      SERVER_DBG("Rejected out-of-bounds chunk %lu+%lu for buffer of %lu bytes",
                 ChunkStart, ChunkSize, TotalSize)
      Reply->set_number(-1);
      return Status::OK;
    }

    memcpy(HostCopy.get() + ChunkStart, Request.data().data(), ChunkSize);
  }

  // The stream ended before any message arrived: there is nothing to submit.
  if (!HostCopy) {
    SERVER_DBG("Data submit stream ended without any data")
    Reply->set_number(-1);
    return Status::OK;
  }

  // Request still holds the fields of the last message that was read.
  Reply->set_number(PM->Devices[Request.device_id()]->RTL->data_submit(
      mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
      HostCopy.get(), TotalSize));

  SERVER_DBG("Submitted %lu bytes to (%p) on device %d", TotalSize,
             (void *)Request.tgt_ptr(), Request.device_id())

  return Status::OK;
}
|
|
|
|
/// DataRetrieve RPC handler (server-streaming).
///
/// Copies Request->size() bytes from the device into a host buffer and
/// streams them back. Transfers larger than BlockSize are split into
/// consecutive chunks of at most BlockSize bytes; smaller transfers are sent
/// as a single final message. Every message carries the start offset, the
/// total size, the payload and the RTL's data_retrieve return code.
///
/// \param Context unused.
/// \param Request device id, target pointer and byte count.
/// \param Writer  stream the Data messages are written to.
/// \returns Status::OK.
Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context,
                                       const RetrieveData *Request,
                                       ServerWriter<Data> *Writer) {
  auto HstPtr = std::make_unique<char[]>(Request->size());

  auto Ret = PM->Devices[Request->device_id()]->RTL->data_retrieve(
      mapHostRTLDeviceId(Request->device_id()), HstPtr.get(),
      (void *)Request->tgt_ptr(), Request->size());

  // Recycle the arena once it has grown past the configured limit.
  if (Arena->SpaceAllocated() >= MaxSize)
    Arena->Reset();

  if (Request->size() > BlockSize) {
    const uint64_t Total = Request->size();
    const uint64_t Block = BlockSize;

    // Walk the buffer with integer offsets. The previous loop derived its
    // trip count from ceil() on a float, which cannot represent sizes above
    // 2^24 exactly and could drop or duplicate the final chunk.
    uint64_t Start = 0;
    while (Start < Total) {
      const uint64_t Remaining = Total - Start;
      const uint64_t End = Start + (Remaining < Block ? Remaining : Block);

      auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());

      Reply->set_start(Start);
      Reply->set_size(Total);
      Reply->set_data((char *)HstPtr.get() + Start, End - Start);
      Reply->set_ret(Ret);

      if (!Writer->Write(*Reply)) {
        // Stop on the first failed write instead of continuing to write to
        // the same stream.
        SERVER_DBG("Broken stream when retrieving data")
        break;
      }

      SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start,
                 End, Total, (void *)Request->tgt_ptr(),
                 mapHostRTLDeviceId(Request->device_id()))

      Start = End;
    }
  } else {
    auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());

    Reply->set_start(0);
    Reply->set_size(Request->size());
    Reply->set_data((char *)HstPtr.get(), Request->size());
    Reply->set_ret(Ret);

    SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(),
               (void *)Request->tgt_ptr(),
               mapHostRTLDeviceId(Request->device_id()))

    Writer->WriteLast(*Reply, WriteOptions());
  }

  return Status::OK;
}
|
|
|
|
/// DataExchange RPC handler.
///
/// Forwards a device-to-device copy to the RTL owning the source device when
/// that RTL provides a data_exchange entry point; otherwise the reply is -1.
///
/// \param Context unused.
/// \param Request source/destination device ids and pointers, plus the size.
/// \param Reply   receives the RTL's return code, or -1 if unsupported.
/// \returns Status::OK.
Status RemoteOffloadImpl::DataExchange(ServerContext *Context,
                                       const ExchangeData *Request,
                                       I32 *Reply) {
  auto &SrcRTL = PM->Devices[Request->src_dev_id()]->RTL;

  if (!SrcRTL->data_exchange) {
    // Entry point not provided by this RTL.
    Reply->set_number(-1);
  } else {
    int32_t Ret = SrcRTL->data_exchange(
        mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
        mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
        Request->size());
    Reply->set_number(Ret);
  }

  SERVER_DBG(
      "Exchanged data asynchronously from device %d (%p) to device %d (%p) of "
      "size %lu",
      mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
      mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
      Request->size())
  return Status::OK;
}
|
|
|
|
/// DataDelete RPC handler.
///
/// Calls data_delete on the owning RTL for the target pointer in \p Request
/// and forwards the return code.
///
/// \param Context unused.
/// \param Request device id and target pointer.
/// \param Reply   receives the RTL's data_delete return value.
/// \returns Status::OK.
Status RemoteOffloadImpl::DataDelete(ServerContext *Context,
                                     const DeleteData *Request, I32 *Reply) {
  auto &RTL = PM->Devices[Request->device_id()]->RTL;

  auto Ret = RTL->data_delete(mapHostRTLDeviceId(Request->device_id()),
                              (void *)Request->tgt_ptr());
  Reply->set_number(Ret);

  SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(),
             mapHostRTLDeviceId(Request->device_id()))
  return Status::OK;
}
|
|
|
|
/// RunTargetRegion RPC handler.
///
/// Unpacks the argument pointers and offsets from \p Request, resolves the
/// entry address from the __tgt_offload_entry referenced by the request, and
/// calls run_region on the owning RTL.
///
/// \param Context unused.
/// \param Request device id, entry pointer, argument pointers and offsets.
/// \param Reply   receives the RTL's run_region return value.
/// \returns Status::OK.
Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context,
                                          const TargetRegion *Request,
                                          I32 *Reply) {
  // Each element must be pointer-sized because the buffer is handed to the
  // RTL as a void ** array. The previous uint8_t element type truncated every
  // pointer to one byte and made the RTL read past the end of the buffer.
  // RunTargetTeamRegion already uses uint64_t.
  std::vector<uint64_t> TgtArgs(Request->arg_num());
  for (auto I = 0; I < Request->arg_num(); I++)
    TgtArgs[I] = (uint64_t)Request->tgt_args()[I];

  std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
  const auto *TgtOffsetItr = Request->tgt_offsets().begin();
  for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
    TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;

  void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;

  int32_t Ret = PM->Devices[Request->device_id()]->RTL->run_region(
      mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
      (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num());

  Reply->set_number(Ret);

  SERVER_DBG("Ran TargetRegion on device %d with %d args",
             mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
  return Status::OK;
}
|
|
|
|
/// RunTargetTeamRegion RPC handler.
///
/// Unpacks the argument pointers and offsets from \p Request, resolves the
/// entry address from the __tgt_offload_entry referenced by the request, and
/// calls run_team_region on the owning RTL with the requested team count,
/// thread limit and loop trip count.
///
/// \param Context unused.
/// \param Request device id, entry pointer, arguments and launch parameters.
/// \param Reply   receives the RTL's run_team_region return value.
/// \returns Status::OK.
Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context,
                                              const TargetTeamRegion *Request,
                                              I32 *Reply) {
  // Argument pointers, stored as 64-bit integers.
  std::vector<uint64_t> TgtArgs(Request->arg_num());
  int ArgIdx = 0;
  for (auto &Arg : TgtArgs)
    Arg = (uint64_t)Request->tgt_args()[ArgIdx++];

  // One offset per argument, read sequentially from the repeated field.
  std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
  const auto *OffsetItr = Request->tgt_offsets().begin();
  for (auto &Offset : TgtOffsets)
    Offset = (ptrdiff_t)*OffsetItr++;

  auto *Entry = (__tgt_offload_entry *)Request->tgt_entry_ptr();
  void *TgtEntryPtr = Entry->addr;

  auto &RTL = PM->Devices[Request->device_id()]->RTL;
  int32_t Ret = RTL->run_team_region(
      mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
      (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(),
      Request->team_num(), Request->thread_limit(), Request->loop_tripcount());

  Reply->set_number(Ret);

  SERVER_DBG("Ran TargetTeamRegion on device %d with %d args",
             mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
  return Status::OK;
}
|
|
|
|
int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) {
|
|
for (auto &RTL : PM->RTLs.UsedRTLs) {
|
|
if (RTLDeviceID - RTL->NumberOfDevices >= 0)
|
|
RTLDeviceID -= RTL->NumberOfDevices;
|
|
else
|
|
break;
|
|
}
|
|
return RTLDeviceID;
|
|
}
|