Add safety module #213
Conversation
The documentation is not available anymore as the PR was closed or merged.
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
More general logic to detect if a module comes from the pipelines module. For now this is only needed for the LDMBert model and the safety checker.
This should probably be in a separate PR.
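A minimal sketch of that detection logic, wrapped in an illustrative helper (the function name and the `pipeline_dir` argument shown here are assumptions for clarity, not the library's actual API — only the two lines inside the function mirror the diff above):

```python
from diffusers import pipelines


def is_pipeline_local_module(module, pipeline_dir):
    # e.g. "diffusers.pipelines.stable_diffusion.safety_checker"
    # -> ["diffusers", "pipelines", "stable_diffusion", "safety_checker"]
    path = module.__module__.split(".")
    # The component is pipeline-local if its import path contains the pipeline
    # folder name and `diffusers.pipelines` actually exposes that submodule.
    return pipeline_dir in path and hasattr(pipelines, pipeline_dir)
```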
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
safety_checker will replace NSFW images (if detected) in image with a black image.
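Roughly, that replacement step could look like the sketch below (a simplified illustration of the idea, not the checker's exact implementation; the function name is made up):

```python
import numpy as np


def blank_out_flagged_images(images: np.ndarray, has_nsfw_concept: list) -> np.ndarray:
    # images: (batch, height, width, 3) array of generated samples.
    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            # Replace the flagged sample with an all-black image of the same shape.
            images[idx] = np.zeros(images[idx].shape)
    return images
```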
Great! Looks very nice to me!
LGTM! I'll test it when I'm added to the
Added you :-)
Looks great! Just a couple of quick comments / questions:

Regarding
LGTM! Our tests are all based on v1-1, but I've checked manually with the safety checkpoint.

One nit: maybe let's still add a logger.warning() about detected NSFW content, so that it's visible in a terminal/notebook too. The "rerun with a different seed" message can be included in the warning as well, and added to the gradio demo too.
Hey @BIGJUN777, could you please open a new issue for this? It would be nice to have a reproducible code snippet as well.
* add SafetyChecker
* better name, fix checker
* add checker in main init
* remove from main init
* update logic to detect pipeline module
* style
* handle all safety logic in safety checker
* draw text
* can't draw
* small fixes
* treat special care as nsfw
* remove commented lines
* update safety checker
This PR adds StableDiffusionSafetyChecker to filter out NSFW content in StableDiffusionPipeline.

The StableDiffusionSafetyChecker contains the CLIP vision model and the vision projection. The NSFW concept embeddings are pre-computed and part of the state_dict. The module takes images and, if NSFW content is detected, replaces the offending image with a black image.

It's added as a new required attribute in StableDiffusionPipeline, so it'll always be loaded.
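For reference, a hedged usage sketch with the checker in place (the checkpoint id and the shape of the returned object are assumptions and may differ from this PR's final API):

```python
from diffusers import StableDiffusionPipeline

# Assumed checkpoint id; substitute whichever Stable Diffusion weights you use.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

output = pipe("a photo of an astronaut riding a horse on mars")

# The safety checker is a required pipeline component, so it always runs:
# any image it flagged as NSFW has already been replaced by a black image here.
image = output.images[0]
image.save("astronaut.png")
```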