Skip to content

Commit

Permalink
fix device to host copy not sync stream in logistic regression mg (#5766)
Browse files Browse the repository at this point in the history

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Micka (https://github.com/lowener)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5766
  • Loading branch information
lijinf2 authored Mar 20, 2024
1 parent 334d796 commit 24bf99b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 2 additions & 0 deletions cpp/src/glm/qn_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ std::vector<T> distinct_mg(const raft::handle_t& handle, T* y, size_t n)

std::vector<size_t> recv_counts_host(n_ranks);
raft::copy(recv_counts_host.data(), recv_counts.data(), n_ranks, stream);
raft::resource::sync_stream(handle);

std::vector<size_t> displs(n_ranks);
size_t pos = 0;
Expand All @@ -88,6 +89,7 @@ std::vector<T> distinct_mg(const raft::handle_t& handle, T* y, size_t n)

std::vector<T> global_unique_y_host(global_unique_y.size());
raft::copy(global_unique_y_host.data(), global_unique_y.data(), global_unique_y.size(), stream);
raft::resource::sync_stream(handle);

return global_unique_y_host;
}
Expand Down
9 changes: 3 additions & 6 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,9 @@ def test_standardization_on_scaled_dataset(
penalty = regularization[0]
C = regularization[1]
l1_ratio = regularization[2]
nrows = int(1e5)
ncols = ncol_and_nclasses[0]
n_classes = ncol_and_nclasses[1]
nrows = int(1e5) if n_classes < 5 else int(2e5)
ncols = ncol_and_nclasses[0]
n_info = ncols
n_redundant = 0
n_parts = 2
Expand Down Expand Up @@ -784,6 +784,7 @@ def to_dask_data(X_train, X_test, y_train, y_test):
# if fit_intercept is false, scale the dataset without mean center
scaler = StandardScaler(with_mean=fit_intercept, with_std=True)
scaler.fit(X_train)
scaler.scale_ = np.sqrt(scaler.var_ * len(X_train) / (len(X_train) - 1))
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

Expand Down Expand Up @@ -842,10 +843,6 @@ def to_dask_data(X_train, X_test, y_train, y_test):
np.abs(mgon_accuracy - mgoff_accuracy) < 1e-3
)

print(f"mgon_coef_origin: {mgon_coef_origin}")
print(f"mgoff.coef_: {mgoff.coef_.to_numpy()}")
print(f"mgon_intercept_origin: {mgon_intercept_origin}")
print(f"mgoff.intercept_: {mgoff.intercept_.to_numpy()}")
assert array_equal(mgon_coef_origin, mgoff.coef_.to_numpy(), tolerance)
assert array_equal(
mgon_intercept_origin, mgoff.intercept_.to_numpy(), tolerance
Expand Down

0 comments on commit 24bf99b

Please sign in to comment.