Skip to content

Mistral

Generates text using the Mistral Batch API.

MistralClient

Bases: BatchClient

Source code in dactyl_generation/mistral_generation.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class MistralClient(BatchClient):
    """Batch client that generates text through the Mistral Batch API."""

    def __init__(self, api_key: str):
        """
        Constructor for Mistral client.

        Args:
            api_key: API key handed to the `Mistral` SDK client.
        """
        super().__init__()
        self.client = Mistral(api_key=api_key)

    def create_message_batch(self, file_name: str, prompts_df: pd.DataFrame) -> Tuple[List[dict], mistralai.models.UploadFileOut]:
        """
        Creates batch of messages to send to Mistral API.

        Each row of `prompts_df` becomes one JSON line holding the row's
        custom ID and, as the request body, every other column of that row.

        Args:
            file_name: Name of file in Mistral API to save as.
            prompts_df: DataFrame containing prompts and generation parameters

        Returns:
            tuple: List of requests sent, UploadFileOut object
        """
        # `.tolist()` yields native Python values; indexing `.values` returns
        # numpy scalars, which `json.dumps` rejects for numeric custom IDs.
        custom_ids = prompts_df[CUSTOM_ID].tolist()
        bodies = prompts_df.drop(columns=[CUSTOM_ID]).to_dict(orient="records")
        list_of_requests = [
            {CUSTOM_ID: custom_id, BODY: body}
            for custom_id, body in zip(custom_ids, bodies)
        ]
        # One JSON document per line (JSONL), encoded once at the I/O boundary.
        payload = "".join(json.dumps(request) + "\n" for request in list_of_requests)
        file = File(file_name=file_name, content=payload.encode("utf-8"))
        return list_of_requests, self.client.files.upload(file=file, purpose=BATCH)

    def start_batch_job(self, input_file: mistralai.models.UploadFileOut, model: str) -> mistralai.models.BatchJobOut:
        """
        Start batch job from input file stored on Mistral API containing prompts.

        The job targets the chat completions endpoint and carries the same
        `job_type` metadata that `get_batch_jobs` passes when listing jobs.

        Args:
            input_file: input file object to create job with
            model: model name to use for generation

        Returns:
            batch_job: Batch job object
        """
        batch_job = self.client.batch.jobs.create(
            input_files=[input_file.id],
            model=model,
            endpoint="/v1/chat/completions",
            metadata={"job_type": "testing"}
        )
        return batch_job

    def create_batch_job(self, file_name: str, prompts_df: pd.DataFrame) -> dict:
        """
        Creates batch job for set of prompts given file name to save Mistral prompts to.

        Args:
            file_name: name of file to upload to Mistral API.
            prompts_df: DataFrame containing generation prompts and parameters.

        Returns:
            info: dictionary containing batch job info

        Raises:
            ValueError: if `prompts_df` does not contain exactly one distinct model.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        models = prompts_df[MODEL].unique()
        if len(models) != 1:
            raise ValueError(
                f"prompts_df must contain exactly one distinct model, found {len(models)}"
            )
        model = models[0]
        prompts, input_file = self.create_message_batch(file_name, prompts_df)
        batch_job = self.start_batch_job(input_file, model)
        # Dump the SDK objects in JSON mode so the returned info holds plain data.
        input_file = input_file.model_dump(mode="json")
        batch_job = batch_job.model_dump(mode="json")
        return {"batch_job": batch_job, INPUT_FILE: input_file, PROMPTS: prompts, API_CALL: MISTRAL}

    def get_batch_jobs(self) -> BatchJobsOut:
        """
        Helper method to get status of all batch jobs.

        Passes the `job_type` metadata used by `start_batch_job` to the
        listing call.

        Returns:
            batch_jobs_list: list of all batch jobs
        """
        return self.client.batch.jobs.list(
            metadata={"job_type": "testing"}
        )

    def get_batch_job_output(self, file_path: str) -> pd.DataFrame:
        """
        Gets batch job results using saved metadata from a local JSON file.

        Args:
            file_path: local JSON file containing output of the `create_batch_job` function

        Returns:
            df: pandas DataFrame of generations.

        Raises:
            RuntimeError: if the batch job reports no output file.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        job_id = data["batch_job"]["id"]
        output_file = self.client.batch.jobs.get(job_id=job_id).output_file
        if not output_file:
            raise RuntimeError(
                f"Batch job {job_id} has no output file; it may still be running or may have failed."
            )
        content = self.client.files.download(file_id=output_file).read().decode("utf-8")
        # Parse the JSONL one line at a time, splitting on "\n" only:
        # `str.splitlines()` also breaks on characters such as U+2028 that may
        # appear unescaped inside JSON strings, and blank lines must be skipped.
        responses = [json.loads(line) for line in content.split("\n") if line.strip()]
        rows = []
        for response in responses:
            body = response[RESPONSE][BODY]
            rows.append({
                CUSTOM_ID: response[CUSTOM_ID],
                TEXT: body[CHOICES][0][MESSAGE][CONTENT],
                TIMESTAMP: str(datetime.fromtimestamp(body[CREATED], tz=timezone.utc)),
            })
        # Flatten each saved request back into one row keyed by its custom ID.
        raw_prompts = pd.DataFrame([{**prompt[BODY], CUSTOM_ID: prompt[CUSTOM_ID]} for prompt in data[PROMPTS]])
        generations = pd.DataFrame(rows)
        return generations.merge(raw_prompts, on=CUSTOM_ID, how="left")

__init__(api_key)

Constructor for Mistral client.

Parameters:

Name Type Description Default
api_key str

API key passed to the Mistral SDK client.

required
Source code in dactyl_generation/mistral_generation.py
19
20
21
22
23
24
25
26
27
def __init__(self, api_key: str):
    """
    Constructor for Mistral client.

    Args:
        api_key: API key handed to the `Mistral` SDK client.
    """
    super().__init__()
    self.client = Mistral(api_key=api_key)

create_batch_job(file_name, prompts_df)

Creates batch job for set of prompts given file name to save Mistral prompts to.

Parameters:

Name Type Description Default
file_name str

name of file to upload to Mistral API.

required
prompts_df DataFrame

DataFrame containing generation prompts and parameters.

required

Returns:

Name Type Description
info dict

dictionary containing batch job info

Source code in dactyl_generation/mistral_generation.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def create_batch_job(self, file_name: str, prompts_df: pd.DataFrame) -> dict:
    """
    Creates batch job for set of prompts given file name to save Mistral prompts to.

    Args:
        file_name: name of file to upload to Mistral API.
        prompts_df: DataFrame containing generation prompts and parameters.

    Returns:
        info: dictionary containing batch job info

    Raises:
        ValueError: if `prompts_df` does not contain exactly one distinct model.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    models = prompts_df[MODEL].unique()
    if len(models) != 1:
        raise ValueError(
            f"prompts_df must contain exactly one distinct model, found {len(models)}"
        )
    model = models[0]
    prompts, input_file = self.create_message_batch(file_name, prompts_df)
    batch_job = self.start_batch_job(input_file, model)
    # Dump the SDK objects in JSON mode so the returned info holds plain data.
    input_file = input_file.model_dump(mode="json")
    batch_job = batch_job.model_dump(mode="json")
    return {"batch_job": batch_job, INPUT_FILE: input_file, PROMPTS: prompts, API_CALL: MISTRAL}

create_message_batch(file_name, prompts_df)

Creates batch of messages to send to Mistral API.

Parameters:

Name Type Description Default
file_name str

Name of file in Mistral API to save as.

required
prompts_df DataFrame

DataFrame containing prompts and generation parameters

required

Returns:

Name Type Description
tuple Tuple[List[dict], UploadFileOut]

List of requests sent, UploadFileOut object

Source code in dactyl_generation/mistral_generation.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def create_message_batch(self, file_name: str, prompts_df: pd.DataFrame) -> Tuple[List[dict], mistralai.models.UploadFileOut]:
    """
    Creates batch of messages to send to Mistral API.

    Each row of `prompts_df` becomes one JSON line holding the row's
    custom ID and, as the request body, every other column of that row.

    Args:
        file_name: Name of file in Mistral API to save as.
        prompts_df: DataFrame containing prompts and generation parameters

    Returns:
        tuple: List of requests sent, UploadFileOut object
    """
    # `.tolist()` yields native Python values; indexing `.values` returns
    # numpy scalars, which `json.dumps` rejects for numeric custom IDs.
    custom_ids = prompts_df[CUSTOM_ID].tolist()
    bodies = prompts_df.drop(columns=[CUSTOM_ID]).to_dict(orient="records")
    list_of_requests = [
        {CUSTOM_ID: custom_id, BODY: body}
        for custom_id, body in zip(custom_ids, bodies)
    ]
    # One JSON document per line (JSONL), encoded once at the I/O boundary.
    payload = "".join(json.dumps(request) + "\n" for request in list_of_requests)
    file = File(file_name=file_name, content=payload.encode("utf-8"))
    return list_of_requests, self.client.files.upload(file=file, purpose=BATCH)

get_batch_job_output(file_path)

Gets batch job results using saved metadata from a local JSON file.

Parameters:

Name Type Description Default
file_path str

local JSON file containing output of the create_batch_job function

required

Returns:

Name Type Description
df DataFrame

pandas DataFrame of generations.

Source code in dactyl_generation/mistral_generation.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def get_batch_job_output(self, file_path: str) -> pd.DataFrame:
    """
    Gets batch job results using saved metadata from a local JSON file.

    Args:
        file_path: local JSON file containing output of the `create_batch_job` function

    Returns:
        df: pandas DataFrame of generations.

    Raises:
        RuntimeError: if the batch job reports no output file.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    job_id = data["batch_job"]["id"]
    output_file = self.client.batch.jobs.get(job_id=job_id).output_file
    if not output_file:
        raise RuntimeError(
            f"Batch job {job_id} has no output file; it may still be running or may have failed."
        )
    content = self.client.files.download(file_id=output_file).read().decode("utf-8")
    # Parse the JSONL one line at a time, splitting on "\n" only:
    # `str.splitlines()` also breaks on characters such as U+2028 that may
    # appear unescaped inside JSON strings, and blank lines must be skipped.
    responses = [json.loads(line) for line in content.split("\n") if line.strip()]
    rows = []
    for response in responses:
        body = response[RESPONSE][BODY]
        rows.append({
            CUSTOM_ID: response[CUSTOM_ID],
            TEXT: body[CHOICES][0][MESSAGE][CONTENT],
            TIMESTAMP: str(datetime.fromtimestamp(body[CREATED], tz=timezone.utc)),
        })
    # Flatten each saved request back into one row keyed by its custom ID.
    raw_prompts = pd.DataFrame([{**prompt[BODY], CUSTOM_ID: prompt[CUSTOM_ID]} for prompt in data[PROMPTS]])
    generations = pd.DataFrame(rows)
    return generations.merge(raw_prompts, on=CUSTOM_ID, how="left")

get_batch_jobs()

Helper method to get status of all batch jobs.

Returns:

Name Type Description
batch_jobs_list BatchJobsOut

list of all batch jobs

Source code in dactyl_generation/mistral_generation.py
 98
 99
100
101
102
103
104
105
106
107
def get_batch_jobs(self) -> BatchJobsOut:
    """
    Query the Mistral batch job listing, passing the `job_type: testing` metadata.

    Returns:
        batch_jobs_list: whatever the batch job listing call returns
    """
    job_filter = {"job_type": "testing"}
    jobs_api = self.client.batch.jobs
    return jobs_api.list(metadata=job_filter)

start_batch_job(input_file, model)

Start batch job from input file stored on Mistral API containing prompts.

Parameters:

Name Type Description Default
input_file UploadFileOut

input file object to create job with

required
model str

model name to use for generation

required

Returns:

Name Type Description
batch_job BatchJobOut

Batch job object

Source code in dactyl_generation/mistral_generation.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def start_batch_job(self, input_file: mistralai.models.UploadFileOut, model: str) -> mistralai.models.BatchJobOut:
    """
    Launch a chat-completions batch job over a prompt file already uploaded to Mistral.

    Args:
        input_file: uploaded file object; its `id` is the job's only input file
        model: name of the model the job should generate with

    Returns:
        batch_job: object returned by the batch job creation call
    """
    # Collect the creation arguments first so the request reads as one unit.
    job_request = {
        "input_files": [input_file.id],
        "model": model,
        "endpoint": "/v1/chat/completions",
        "metadata": {"job_type": "testing"},
    }
    return self.client.batch.jobs.create(**job_request)