This message was deleted.
# ask-for-help
s
This message was deleted.
f
To solve the problem, I created and registered the following class:
Copy code
class PyTorchTensorDictContainer(DataContainer[Dict[str, torch.Tensor], torch.Tensor]):
    """DataContainer for batches that are dicts of named tensors.

    Supports PyTorch models whose forward pass takes keyword inputs,
    i.e. a ``Dict[str, torch.Tensor]`` per batch. Every tensor in a batch
    is assumed to share the same size along ``batch_dim`` — TODO confirm
    this invariant holds for all callers.
    """

    @classmethod
    def batches_to_batch(
        cls,
        batches: t.Sequence[Dict[str, torch.Tensor]],
        batch_dim: int = 0,
    ) -> t.Tuple[Dict[str, torch.Tensor], list[int]]:
        """Concatenate sub-batches along ``batch_dim``.

        Returns the merged batch and the cumulative boundary indices
        (starting at 0) needed by :meth:`batch_to_batches` to undo the merge.
        """
        keys = batches[0].keys()
        # Fixed: the original text contained a Slack-mangled link in place
        # of the plain `torch.cat` call.
        batch = {
            key: torch.cat(tuple(subbatch[key] for subbatch in batches), dim=batch_dim)
            for key in keys
        }
        # Use one explicit key to measure sizes instead of relying on the
        # loop variable leaking out of the loop above; all keys are assumed
        # to have the same length along batch_dim.
        first_key = next(iter(keys))
        indices = [0] + list(
            itertools.accumulate(
                subbatch[first_key].shape[batch_dim] for subbatch in batches
            )
        )
        return batch, indices

    @classmethod
    def batch_to_batches(
        cls,
        batch: Dict[str, torch.Tensor],
        indices: t.Sequence[int],
        batch_dim: int = 0,
    ) -> t.List[Dict[str, torch.Tensor]]:
        """Split ``batch`` back into sub-batches at the given boundary indices.

        ``indices`` is the cumulative-boundary list produced by
        :meth:`batches_to_batch` (first element 0, last element total size).
        """
        sizes = [indices[i] - indices[i - 1] for i in range(1, len(indices))]
        output: t.List[Dict[str, torch.Tensor]] = [{} for _ in sizes]
        for key, tensor in batch.items():
            for i, piece in enumerate(torch.split(tensor, sizes, dim=batch_dim)):
                output[i][key] = piece
        return output

    @classmethod
    @inject
    def to_payload(  # pylint: disable=arguments-differ
        cls,
        batch: Dict[str, torch.Tensor],
        batch_dim: int = 0,
    ) -> Payload:
        """Serialize a dict-of-tensors batch into a pickle payload.

        Tensors are moved to CPU and converted to numpy arrays before
        pickling so the payload is device-independent.
        """
        arrays = {key: tensor.cpu().numpy() for key, tensor in batch.items()}
        # Take the batch size from an arbitrary entry rather than from a
        # loop variable leaked out of a prior loop (as the original did);
        # all entries are assumed equal-sized along batch_dim.
        any_array = next(iter(arrays.values()))
        return cls.create_payload(
            pickle.dumps(arrays),
            batch_size=any_array.shape[batch_dim],
            meta={"plasma": False},
        )

    @classmethod
    @inject
    def from_payload(  # pylint: disable=arguments-differ
        cls,
        payload: Payload,
    ) -> Dict[str, torch.Tensor]:
        """Deserialize a payload back into a dict of no-grad tensors.

        SECURITY NOTE(review): ``pickle.loads`` executes arbitrary code if
        the payload is attacker-controlled; this assumes payloads only come
        from this service's own :meth:`to_payload`.
        """
        arrays: Dict[str, np.ndarray] = pickle.loads(payload.data)
        return {
            key: torch.from_numpy(arr).requires_grad_(False)
            for key, arr in arrays.items()
        }

    @classmethod
    @inject
    def batch_to_payloads(  # pylint: disable=arguments-differ
        cls,
        batch: Dict[str, torch.Tensor],
        indices: t.Sequence[int],
        batch_dim: int = 0,
    ) -> t.List[Payload]:
        """Split a merged batch and serialize each sub-batch to a payload."""
        sub_batches = cls.batch_to_batches(batch, indices, batch_dim)
        return [cls.to_payload(sub, batch_dim=batch_dim) for sub in sub_batches]

    @classmethod
    @inject
    def from_batch_payloads(  # pylint: disable=arguments-differ
        cls,
        payloads: t.Sequence[Payload],
        batch_dim: int = 0,
    ) -> t.Tuple[Dict[str, torch.Tensor], list[int]]:
        """Deserialize payloads and merge them into one batch plus indices."""
        sub_batches = [cls.from_payload(payload) for payload in payloads]
        return cls.batches_to_batch(sub_batches, batch_dim)
However, I need to use internal BentoML API, which is far from ideal. Since it's common for some PyTorch models to receive a dict, I think it would be good if BentoML could support it by default. Otherwise, expose and document this API so we can implement a class like this when necessary.
I'll still need to do something for it to work on GPU, as
PytorchModelRunnable
isn't ready for a dict of tensors either.
Solved it by forking
make_pytorch_runnable_method()
as well.
s
That's something we should definitely get around to fixing; I'll take a look at this tomorrow.
c
cc @Sean we were just talking about supporting this
f
Awesome. Do you want me to create an issue for it, @Chaoyu?
I was thinking about how it could be more generic. Maybe we could have a
DictContainer
that for each value in the dict uses
AutoContainer
.
s
Especially in non-batched workloads our autocontainer should just use pickle, I think.
Probably we should have a picklablecontainer or something, and a separate container for lists and other built in Python things.
f
@Fernando Camargo IMO this feature is good to have. Would you mind opening an issue for it? Let’s move the discussion to it then.
👍 3
f