Transformers Engine Parameter Passing Error in GRPO Training (The following model_kwargs are not used by the model) #9131

@wolfvoid

Description

Checklist

  • I have searched existing issues, and this is a new feature request.

Feature Request Description

Issue Description:
────────────────────────────────────────

Environment

  • ms-swift: 4.2.0dev0
  • Python: 3.10
  • PyTorch: 2.10.0
  • vLLM: 0.19.0
  • GPU: NVIDIA H800 * 8
  • DeepSpeed: zero2

────────────────────────────────────────

Problem Description

When performing GRPO training with ms-swift, Triton kernel compatibility issues between vLLM V1 and the DeepSpeed multi-process setup (see previous context) forced me to disable vLLM (--use_vllm false) and switch to the transformers engine for rollout generation.

However, after disabling vLLM, training crashes at the first step with the following error:

ValueError: The following model_kwargs are not used by the model:
['solution', 'solution_rich', 'task_type', 'schema_version', 'source_file',
'split_name', 'prompt_id', 'request_id']

────────────────────────────────────────

Error Stack Trace

[rank0]: Traceback (most recent call last):
[rank0]: File "/userplace2/username/ms-swift/swift/cli/rlhf.py", line 7, in
[rank0]: rlhf_main()
[rank0]: File "/userplace2/username/ms-swift/swift/pipelines/train/rlhf.py", line 246, in rlhf_main
[rank0]: return SwiftRLHF(args).main()
[rank0]: File "/userplace2/username/ms-swift/swift/pipelines/base.py", line 52, in main
[rank0]: result = self.run()
[rank0]: File "/userplace2/username/ms-swift/swift/ray/base.py", line 168, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: File "/userplace2/username/ms-swift/swift/pipelines/train/sft.py", line 197, in run
[rank0]: return self.train(trainer)
[rank0]: File "/userplace2/username/ms-swift/swift/pipelines/train/sft.py", line 270, in train
[rank0]: trainer.train(resume_checkpoint)
[rank0]: File "/userplace2/username/ms-swift/swift/trainers/mixin.py", line 896, in train
[rank0]: res = super().train(*args, **kwargs)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/trainer.py", line 1425, in train
[rank0]: return inner_training_loop(
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/trainer.py", line 1507, in _inner_training_loop
[rank0]: self._run_epoch(
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/trainer.py", line 1735, in _run_epoch
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/grpo_trainer.py", line 1917, in training_step
[rank0]: return super().training_step(model, inputs, num_items_in_batch)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py", line 1085, in training_step
[rank0]: output = super().training_step(model, inputs, num_items_in_batch)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/trainer.py", line 1901, in training_step
[rank0]: inputs = self._prepare_inputs(inputs)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/utils.py", line 613, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/grpo_trainer.py", line 203, in _prepare_inputs
[rank0]: generation_batch = self._generate_and_score_completions(generation_batch)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/utils.py", line 613, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/grpo_trainer.py", line 236, in _generate_and_score_completions
[rank0]: inputs = self._generate_completions(inputs)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/grpo_trainer.py", line 222, in _generate_completions
[rank0]: results = self._infer_single_or_multi_turn(inputs, self.request_config)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/rollout_mixin.py", line 765, in _infer_single_or_multi_turn
[rank0]: rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/rollout_mixin.py", line 729, in _rollout
[rank0]: rollout_outputs = self._colocate_rollout(inputs, request_config)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/rollout_mixin.py", line 1088, in _colocate_rollout
[rank0]: outputs: List[RolloutOutput] = self._engine_infer(infer_requests=inputs, request_config=request_config)
[rank0]: File "/userplace2/username/ms-swift/swift/rlhf_trainers/rollout_mixin.py", line 1109, in _engine_infer
[rank0]: res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
[rank0]: File "/userplace2/username/ms-swift/swift/infer_engine/transformers_engine.py", line 577, in infer
[rank0]: res += self._infer(infer_requests_samples, request_config, adapter_request=adapter_request)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/userplace2/username/ms-swift/swift/infer_engine/transformers_engine.py", line 546, in _infer
[rank0]: res = infer_func(**kwargs)
[rank0]: File "/userplace2/username/ms-swift/swift/infer_engine/transformers_engine.py", line 397, in _infer_full
[rank0]: output = dict(self.template.generate(self.model, **generate_kwargs))
[rank0]: File "/user2/username/ms-swift/swift/template/base.py", line 666, in generate
[rank0]: return model.generate(*args, **kwargs)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/generation/utils.py", line 2360, in generate
[rank0]: self._validate_model_kwargs(model_kwargs.copy())
[rank0]: File "/opt/conda/envs/user_grpo/lib/python3.10/site-packages/transformers/generation/utils.py", line 1557, in _validate_model_kwargs
[rank0]: raise ValueError(
[rank0]: ValueError: The following model_kwargs are not used by the model: ['solution', 'solution_rich', 'task_type', 'schema_version', 'source_file', 'split_name',
'prompt_id', 'request_id'] (note: typos in the generate arguments will also show up in this list)

What is the problem this time?

────────────────────────────────────────

Root Cause

GRPO training data samples contain business fields (such as solution, task_type, and prompt_id) that are used for reward calculation but should not be passed to the model's
generate() method.

In the _infer_full method of transformers_engine.py:

generate_kwargs = {'generation_config': generation_config, **inputs}

This unpacks the entire inputs dictionary (including all data fields) and passes it to generate(), causing transformers to raise an error.

The vLLM engine likely has internal filtering logic, but the transformers engine does not, thus exposing this issue.
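The failure mode can be reproduced in plain Python, independent of transformers. The dictionary below is a hypothetical illustration (its field names mirror the error message); it shows how dict unpacking forwards reward-only fields alongside the model tensors:

```python
# Hypothetical rollout inputs: model tensors mixed with reward-only data fields.
inputs = {
    "input_ids": [[1, 2, 3]],       # consumed by model.generate()
    "attention_mask": [[1, 1, 1]],  # consumed by model.generate()
    "solution": "42",               # reward-only field
    "task_type": "math",            # reward-only field
}
generation_config = {"max_new_tokens": 16}

# The pattern from _infer_full: everything in `inputs` reaches generate().
generate_kwargs = {"generation_config": generation_config, **inputs}

# The reward-only fields survive the merge, so transformers'
# _validate_model_kwargs rejects them with the ValueError above.
assert "solution" in generate_kwargs
assert "task_type" in generate_kwargs
```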

────────────────────────────────────────

Solution

Modify swift/infer_engine/transformers_engine.py to dynamically filter non-model parameters in the _infer_full method:

Original code (lines 388-390):

    def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
                    adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
                    template_inputs) -> List[ChatCompletionResponse]:
        # bos_token TODO: encoder-decoder
        generate_kwargs = {'generation_config': generation_config, **inputs}

Modified code (requires `import inspect` at the top of the file):

    def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
                    adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
                    template_inputs) -> List[ChatCompletionResponse]:
        # bos_token TODO: encoder-decoder
        # Filter out non-model kwargs (data fields that should not reach model.generate).
        # Dynamically read the model's forward parameters for better compatibility.
        model_forward_params = set(inspect.signature(self.model.forward).parameters.keys())
        filtered_inputs = {k: v for k, v in inputs.items() if k in model_forward_params}
        generate_kwargs = {'generation_config': generation_config, **filtered_inputs}

Modification logic:

  1. Use inspect.signature(self.model.forward).parameters to dynamically get all parameter names accepted by the model's forward method
  2. Filter inputs to only keep parameters the model can accept
  3. Pass the filtered parameters to generate()

This ensures that regardless of what business fields the dataset contains, they won't be incorrectly passed to the model.
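The three steps above can be sketched against a toy model; ToyModel and its forward signature are stand-ins for illustration, not ms-swift code:

```python
import inspect

class ToyModel:
    # Stand-in for a transformers model; only these forward params exist.
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

model = ToyModel()
inputs = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "solution": "42",       # reward-only field, must be dropped
    "prompt_id": "p-001",   # reward-only field, must be dropped
}

# Step 1: dynamically collect the parameter names forward() accepts
# (inspect.signature on the bound method already excludes `self`).
model_forward_params = set(inspect.signature(model.forward).parameters.keys())
# Step 2: keep only the keys the model understands.
filtered_inputs = {k: v for k, v in inputs.items() if k in model_forward_params}
# Step 3: only the filtered kwargs reach generate().
generate_kwargs = {"generation_config": {}, **filtered_inputs}

print(sorted(filtered_inputs))  # ['attention_mask', 'input_ids']
```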

────────────────────────────────────────

Verification

After the modification, training runs normally:

{'loss': '0.02349', 'grad_norm': '1.644', 'reward': '0.8742', ...}
{'loss': '-0.0999', 'grad_norm': '2.111', 'reward': '0.8517', ...}

────────────────────────────────────────

Recommendation

It is recommended that ms-swift add this filtering logic by default in the transformers engine to improve compatibility with scenarios like GRPO.

Related issues

#5157 mentions the problem, but it appears to have been closed without being resolved.

#9119 hits the same problem as well; I hope this helps.

Pull Request

If this analysis is correct, I hope the maintainers can fix the problem.
However, I am not sure whether my solution is the right approach.

Metadata

Labels

enhancement (New feature or request)
