Slackbot
04/26/2023, 9:05 PMFernando Camargo
04/26/2023, 9:42 PMclass PyTorchTensorDictContainer(DataContainer[Dict[str, torch.Tensor], torch.Tensor]):
@classmethod
def batches_to_batch(
    cls,
    batches: t.Sequence[Dict[str, torch.Tensor]],
    batch_dim: int = 0,
) -> t.Tuple[Dict[str, torch.Tensor], list[int]]:
    """Merge several dict-of-tensor batches into one batch.

    For every key of the first sub-batch, the tensors stored under that key
    in all sub-batches are concatenated along ``batch_dim``.

    Args:
        batches: Sub-batches to merge. Every sub-batch must contain the keys
            of the first one.
        batch_dim: Dimension along which the tensors are concatenated.

    Returns:
        A tuple ``(batch, indices)``. ``batch`` maps each key to its
        concatenated tensor. ``indices`` holds the cumulative offsets of
        the sub-batches along ``batch_dim``, starting with ``0``, so it has
        ``len(batches) + 1`` entries.

    Raises:
        ValueError: If ``batches`` is empty or the first sub-batch has no
            keys, since no batch sizes can be derived in that case.
    """
    if not batches:
        raise ValueError("batches_to_batch() requires at least one batch")

    keys = list(batches[0])
    if not keys:
        raise ValueError("batches_to_batch() requires at least one key per batch")

    batch = {
        key: torch.cat([subbatch[key] for subbatch in batches], dim=batch_dim)
        for key in keys
    }

    # Sizes are read from a single reference key (the last one, as before);
    # all tensors of a sub-batch are assumed to share the same size along
    # batch_dim -- TODO confirm with callers.
    reference_key = keys[-1]
    indices = list(
        itertools.accumulate(
            (subbatch[reference_key].shape[batch_dim] for subbatch in batches),
            initial=0,
        )
    )
    return batch, indices
@classmethod
def batch_to_batches(
    cls,
    batch: Dict[str, torch.Tensor],
    indices: t.Sequence[int],
    batch_dim: int = 0,
) -> t.List[Dict[str, torch.Tensor]]:
    """Split one dict-of-tensor batch back into its sub-batches.

    Args:
        batch: Mapping from key to a tensor holding all sub-batches
            concatenated along ``batch_dim``.
        indices: Cumulative offsets of the sub-batches along ``batch_dim``;
            consecutive entries delimit one sub-batch.
        batch_dim: Dimension along which the tensors are split.

    Returns:
        One dict per sub-batch (``len(indices) - 1`` of them), each holding
        the matching slice of every tensor in ``batch``.
    """
    # Turn the cumulative offsets into per-sub-batch lengths.
    sizes = [stop - start for start, stop in zip(indices, indices[1:])]
    result = [{} for _ in sizes]
    for name, tensor in batch.items():
        chunks = torch.split(tensor, sizes, dim=batch_dim)
        for slot, chunk in zip(result, chunks):
            slot[name] = chunk
    return result
@classmethod
@inject
def to_payload(  # pylint: disable=arguments-differ
    cls,
    batch: Dict[str, torch.Tensor],
    batch_dim: int = 0,
) -> Payload:
    """Serialize a dict of tensors into a single payload.

    Each tensor is detached, moved to the CPU and converted to a numpy
    array; the resulting dict is pickled and handed to
    ``cls.create_payload``. The caller's dict is not modified.

    Args:
        batch: Mapping from key to tensor.
        batch_dim: Dimension whose size is reported as the payload's
            ``batch_size``.

    Returns:
        The payload produced by ``cls.create_payload``.

    Raises:
        ValueError: If ``batch`` is empty, since no batch size can be
            derived in that case.
    """
    if not batch:
        raise ValueError("to_payload() requires a dict with at least one tensor")

    # detach() first: Tensor.numpy() refuses tensors that require grad.
    arrays = {
        key: tensor.detach().cpu().numpy() for key, tensor in batch.items()
    }

    # The batch size is read from a single reference entry (the last one, as
    # before); all tensors are assumed to share the same size along
    # batch_dim -- TODO confirm with callers.
    reference = list(arrays.values())[-1]
    return cls.create_payload(
        pickle.dumps(arrays),
        batch_size=reference.shape[batch_dim],
        meta={"plasma": False},
    )
@classmethod
@inject
def from_payload(  # pylint: disable=arguments-differ
    cls,
    payload: Payload,
) -> Dict[str, torch.Tensor]:
    """Rebuild a dict of tensors from a payload.

    Args:
        payload: Payload whose ``data`` is a pickled dict mapping keys to
            numpy arrays.

    Returns:
        A dict with the same keys, each array wrapped as a tensor with
        gradient tracking disabled.
    """
    # NOTE(review): pickle.loads can execute arbitrary code; this is only
    # safe if payloads always come from a trusted source -- confirm.
    arrays: Dict[str, np.ndarray] = pickle.loads(payload.data)
    return {
        name: torch.from_numpy(array).requires_grad_(False)
        for name, array in arrays.items()
    }
@classmethod
@inject
def batch_to_payloads(  # pylint: disable=arguments-differ
    cls,
    batch: Dict[str, torch.Tensor],
    indices: t.Sequence[int],
    batch_dim: int = 0,
) -> t.List[Payload]:
    """Split a merged batch and serialize every sub-batch separately.

    Args:
        batch: Mapping from key to tensor, holding all sub-batches.
        indices: Cumulative offsets passed on to ``cls.batch_to_batches``.
        batch_dim: Dimension used both for splitting and for the payloads'
            batch size.

    Returns:
        One payload per sub-batch, in sub-batch order.
    """
    return [
        cls.to_payload(subbatch, batch_dim=batch_dim)
        for subbatch in cls.batch_to_batches(batch, indices, batch_dim)
    ]
@classmethod
@inject
def from_batch_payloads(  # pylint: disable=arguments-differ
    cls,
    payloads: t.Sequence[Payload],
    batch_dim: int = 0,
) -> t.Tuple[Dict[str, torch.Tensor], list[int]]:
    """Deserialize several payloads and merge them into one batch.

    Args:
        payloads: Payloads to decode with ``cls.from_payload``.
        batch_dim: Dimension along which the decoded batches are merged.

    Returns:
        Whatever ``cls.batches_to_batch`` returns for the decoded batches:
        the merged dict of tensors and the list of cumulative offsets.
    """
    decoded = list(map(cls.from_payload, payloads))
    return cls.batches_to_batch(decoded, batch_dim)
However, I need to use BentoML's internal API, which is far from ideal. Since it's common for some PyTorch models to receive a dict, I think it would be good if BentoML could support it by default. Otherwise, expose and document this API so we can implement a class like this when necessary.Fernando Camargo
04/26/2023, 10:16 PMPytorchModelRunnable
isn't ready for a dict of tensors either.Fernando Camargo
04/26/2023, 10:27 PMmake_pytorch_runnable_method()
as well.sauyon
04/27/2023, 7:27 AMChaoyu
04/27/2023, 3:34 PMFernando Camargo
04/27/2023, 10:48 PMFernando Camargo
04/27/2023, 10:51 PMDictContainer
that for each value in the dict uses AutoContainer
.sauyon
04/27/2023, 10:52 PMsauyon
04/27/2023, 10:53 PMFrost Ming
04/28/2023, 3:49 AMFernando Camargo
04/28/2023, 10:06 PM