Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: logger accepts arrays as data and use tensorboard_logger Logger #131

Merged
merged 3 commits into from
Dec 4, 2024

Conversation

Your-Cheese
Copy link
Contributor

What?

I was running experiments with MuZero when I ran into some errors with the loggers. I made changes so that the logger accepts Jax arrays as data to log and uses a tensorboard_logger Logger object instead of a global variable.

Why?

Somehow, a Jax array gets passed as input to the logger every so often, resulting in an error.
When using Hydra to perform multiple runs with the Tensorboard logger enabled, an error occurs on the second run due to the logger from the previous run existing and already being configured because it is based on a global variable in the tensorboard_logger library.

How?

In the case that a Jax array gets passed to the logger, it converts it to a scalar.
By using an instance of the Logger class instead of a global variable, the Tensorboard logger properly goes out of scope between runs.

Extra

I also had issues with running the commit linter step when making the commit because it would say "Please add rules to your commitlint.config.js". Making this change to the dependencies resolved it. I'm not sure if it's an issue on my end or not, although I did follow the steps for installing the pre-commit hooks.

@EdanToledo
Copy link
Owner

Hello, thanks so much for spotting this and offering a fix. I'm relatively happy with the tensorboard change however this issue of a jax array being passed seems as if it would be common for all loggers. In that case, we should be making a fix in the StoixLogger class as that filters down to all the loggers being used. That way, the fix will be present for all the other loggers. If you're happy to make changes to this PR, I'd like the following:

  1. Remove the if statements and .item() calls from the individual logger.
  2. Perform a check like this after the np.mean or the "describe" function call in the stoix logger
  3. check to see if the problem is fixed.

@Your-Cheese
Copy link
Contributor Author

I really appreciate the feedback and being so specific on what to change. I moved the check for jax arrays to the StoixLogger log function, and it seems to work fine.

stoix/utils/logger.py Outdated Show resolved Hide resolved
.item() if value is either jax array or numpy array
@EdanToledo
Copy link
Owner

Thanks so much for the changes. Once the tests pass, I'll merge.

Copy link
Owner

@EdanToledo EdanToledo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@EdanToledo EdanToledo merged commit 0dcf410 into EdanToledo:main Dec 4, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants