Skip to content

Commit d51ee0d

Browse files
authored
Merge pull request #4843 from valkrypton/boolean-to-control-0-variadic-variants
added boolean argument in the #[variadic] macro to control generation of 0-variadic-variants
2 parents 143b254 + 4c8dc55 commit d51ee0d

File tree

5 files changed

+141
-23
lines changed

5 files changed

+141
-23
lines changed

diesel_compile_tests/tests/fail/derive/bad_variadic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ extern "SQL" {
77
//~^ ERROR: invalid ABI: found `SQL`
88

99
#[variadic(not_a_literal_number)]
10-
//~^ ERROR: expected integer literal, the correct format is `#[variadic(3)]`
10+
//~^ ERROR: expect `last_arguments`, the correct format is `#[variadic(last_arguments = 3)]` or `#[variadic(last_arguments = 3, skip_zero_argument_variant = true)]`
1111
//~| ERROR: cannot find attribute `variadic` in this scope
1212
fn f();
1313
}
1414

1515
#[declare_sql_function]
1616
extern "SQL" {
17-
#[variadic(3)]
17+
#[variadic(last_arguments = 3)]
1818
//~^ ERROR: invalid variadic argument count: not enough function arguments
1919
fn g<A: SqlType>(a: A);
2020
}

diesel_compile_tests/tests/fail/derive/bad_variadic.stderr

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
error: expected integer literal, the correct format is `#[variadic(3)]`
1+
error: expect `last_arguments`, the correct format is `#[variadic(last_arguments = 3)]` or `#[variadic(last_arguments = 3, skip_zero_argument_variant = true)]`
22
--> tests/fail/derive/bad_variadic.rs:9:16
33
|
44
LL | #[variadic(not_a_literal_number)]
55
| ^^^^^^^^^^^^^^^^^^^^
66

77
error: invalid variadic argument count: not enough function arguments
8-
--> tests/fail/derive/bad_variadic.rs:17:16
8+
--> tests/fail/derive/bad_variadic.rs:17:33
99
|
10-
LL | #[variadic(3)]
11-
| ^
10+
LL | #[variadic(last_arguments = 3)]
11+
| ^
1212

1313
error: cannot find attribute `variadic` in this scope
1414
--> tests/fail/derive/bad_variadic.rs:9:7

diesel_compile_tests/tests/fail/derive/return_type_helper_errors.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ mod with_return_type_helpers {
1919
fn f<A: SingleValue>(a: <A as TypeWrapper>::Type);
2020
//~^ ERROR: cannot find argument corresponding to the generic
2121

22-
#[variadic(1)]
22+
#[variadic(last_arguments = 1)]
2323
fn g<A: SingleValue>(a: <A as TypeWrapper>::Type);
2424
//~^ ERROR: cannot find argument corresponding to the generic
2525

2626
#[skip_return_type_helper]
2727
fn h<A: SingleValue>(a: <A as TypeWrapper>::Type);
2828

2929
#[skip_return_type_helper]
30-
#[variadic(1)]
30+
#[variadic(last_arguments = 1)]
3131
fn i<A: SingleValue>(a: <A as TypeWrapper>::Type);
3232
}
3333
}
@@ -39,9 +39,21 @@ mod without_return_type_helpers {
3939
extern "SQL" {
4040
fn f<A: SingleValue>(a: <A as TypeWrapper>::Type);
4141

42-
#[variadic(1)]
42+
#[variadic(last_arguments = 1)]
4343
fn g<A: SingleValue>(a: <A as TypeWrapper>::Type);
4444
}
4545
}
4646

47+
mod backward_compatibility_test {
48+
use super::*;
49+
50+
#[declare_sql_function]
51+
extern "SQL" {
52+
// this should not generate any error to stay backward compatible with
53+
// diesel 2.3
54+
#[variadic(1)]
55+
fn f<A: SingleValue>(a: <A as TypeWrapper>::Type);
56+
}
57+
}
58+
4759
fn main() {}

diesel_derives/src/lib.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,12 +2385,46 @@ const AUTO_TYPE_DEFAULT_FUNCTION_TYPE_CASE: dsl_auto_type::Case = dsl_auto_type:
23852385
/// }
23862386
/// ```
23872387
///
2388+
/// Optionally, a second named boolean argument `skip_zero_argument_variant` can be provided to
2389+
/// control whether the 0-argument variant is generated. By default, (omitted or `false`),
2390+
/// the 0-argument variant is included. Set it to `true` to skip generating the 0-argument
2391+
/// variant for functions that require at least one variadic argument. If you specify the boolean
2392+
/// argument, the first argument has to be named `last_arguments` for clarity.
2393+
///
2394+
/// Example:
2395+
///
2396+
/// ```ignore
2397+
/// #[declare_sql_function]
2398+
/// extern "SQL" {
2399+
/// #[variadic(last_arguments = 2, skip_zero_argument_variant = true)]
2400+
/// fn foo<A, B, C>(a: A, b: B, c: C) -> Text;
2401+
/// }
2402+
/// ```
2403+
///
2404+
/// Which will be equivalent to
2405+
///
2406+
/// ```ignore
2407+
/// #[declare_sql_function]
2408+
/// extern "SQL" {
2409+
/// #[sql_name = "foo"]
2410+
/// fn foo_1<A, B1, C1>(a: A, b_1: B1, c_1: C1) -> Text;
2411+
///
2412+
/// #[sql_name = "foo"]
2413+
/// fn foo_2<A, B1, C1, B2, C2>(a: A, b_1: B1, c_1: C1, b_2: B2, c_2: C2) -> Text;
2414+
///
2415+
/// ...
2416+
/// }
2417+
/// ```
2418+
///
23882419
/// ### Controlling the generation of variadic function variants
23892420
///
23902421
/// By default, only variants with 0, 1, and 2 repetitions of variadic arguments are generated. To
23912422
/// generate more variants, set the `DIESEL_VARIADIC_FUNCTION_ARGS` environment variable to the
23922423
/// desired number of variants.
23932424
///
2425+
/// • The boolean only affects whether the 0 variant is generated; the total number of variants
2426+
/// (e.g., up to N) still follows DIESEL_VARIADIC_FUNCTION_ARGS or the default.
2427+
///
23942428
/// For a greater convenience this environment variable can also be set in a `.cargo/config.toml`
23952429
/// file as described in the [cargo documentation](https://doc.rust-lang.org/cargo/reference/config.html#env).
23962430
///

diesel_derives/src/sql_function.rs

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ fn expand_one(
8484
let attributes = &mut input.attributes;
8585

8686
let variadic_argument_count = attributes.iter().find_map(|attr| {
87-
if let SqlFunctionAttribute::Variadic(_, c) = &attr.item {
88-
Some((c.base10_parse(), c.span()))
87+
if let SqlFunctionAttribute::Variadic(_, c, flag) = &attr.item {
88+
Some((c.base10_parse(), c.span(), flag.value))
8989
} else {
9090
None
9191
}
9292
});
9393

94-
let Some((variadic_argument_count, variadic_span)) = variadic_argument_count else {
94+
let Some((variadic_argument_count, variadic_span, non_zero_variadic)) = variadic_argument_count
95+
else {
9596
let sql_name = parse_sql_name_attr(&mut input);
9697

9798
return expand_nonvariadic(
@@ -103,14 +104,15 @@ fn expand_one(
103104
};
104105

105106
let variadic_argument_count = variadic_argument_count?;
107+
let start_idx = if non_zero_variadic { 1 } else { 0 };
106108

107109
let variadic_variants = VARIADIC_ARG_COUNT_ENV
108110
.and_then(|arg_count| arg_count.parse::<usize>().ok())
109111
.unwrap_or(VARIADIC_VARIANTS_DEFAULT);
110112

111113
let mut result = TokenStream::new();
112114
let mut helper_type_modules = vec![];
113-
for variant_no in 0..=variadic_variants {
115+
for variant_no in start_idx..=variadic_variants {
114116
let expanded = expand_variadic(
115117
input.clone(),
116118
legacy_helper_type_and_module,
@@ -1054,7 +1056,7 @@ fn merge_attributes(
10541056
(SqlFunctionAttribute::Window { .. }, SqlFunctionAttribute::Window { .. })
10551057
| (SqlFunctionAttribute::SqlName(_, _), SqlFunctionAttribute::SqlName(_, _))
10561058
| (SqlFunctionAttribute::Restriction(_), SqlFunctionAttribute::Restriction(_))
1057-
| (SqlFunctionAttribute::Variadic(_, _), SqlFunctionAttribute::Variadic(_, _))
1059+
| (SqlFunctionAttribute::Variadic(_, _, _), SqlFunctionAttribute::Variadic(_, _, _))
10581060
| (
10591061
SqlFunctionAttribute::SkipReturnTypeHelper(_),
10601062
SqlFunctionAttribute::SkipReturnTypeHelper(_),
@@ -1217,14 +1219,50 @@ fn parse_attribute(
12171219
syn::Meta::List(syn::MetaList {
12181220
path,
12191221
delimiter: syn::MacroDelimiter::Paren(_),
1220-
tokens,
1222+
tokens: _,
12211223
}) if path.is_ident("variadic") => {
1222-
let count: syn::LitInt = syn::parse2(tokens.clone()).map_err(|e| {
1223-
syn::Error::new(
1224-
e.span(),
1225-
format!("{e}, the correct format is `#[variadic(3)]`"),
1226-
)
1227-
})?;
1224+
let (count, flag) = attr
1225+
.parse_args_with(|input: syn::parse::ParseStream| {
1226+
if input.peek(LitInt){
1227+
let count = input.parse::<LitInt>()?;
1228+
if !input.is_empty(){
1229+
return Err(syn::Error::new(input.span(), "unexpected token after positional `#[variadic(..)]`"));
1230+
}
1231+
Ok((count, LitBool::new(false, Span::call_site())))
1232+
}
1233+
else {
1234+
let key: Ident = input.parse()?;
1235+
if key != "last_arguments" {
1236+
return Err(syn::Error::new(key.span(), "expect `last_arguments`"));
1237+
}
1238+
let _eq: Token![=] = input.parse()?;
1239+
let count: LitInt = input.parse()?;
1240+
let skip_zero: LitBool = if input.peek(Token![,]) {
1241+
let _: Token![,] = input.parse()?;
1242+
let key: Ident = input.parse()?;
1243+
if key != "skip_zero_argument_variant" {
1244+
return Err(
1245+
syn::Error::new(
1246+
key.span(), "expect `skip_zero_argument_variant`"
1247+
)
1248+
);
1249+
}
1250+
let _eq: Token![=] = input.parse()?;
1251+
input.parse()?
1252+
} else {
1253+
LitBool::new(false, Span::call_site())
1254+
};
1255+
Ok((count, skip_zero))
1256+
}
1257+
})
1258+
.map_err(|e| {
1259+
syn::Error::new(
1260+
e.span(),
1261+
format!(
1262+
"{e}, the correct format is `#[variadic(last_arguments = 3)]` or `#[variadic(last_arguments = 3, skip_zero_argument_variant = true)]`"
1263+
),
1264+
)
1265+
})?;
12281266
Ok(Some(AttributeSpanWrapper {
12291267
item: SqlFunctionAttribute::Variadic(
12301268
path.require_ident()
@@ -1236,6 +1274,7 @@ fn parse_attribute(
12361274
})?
12371275
.clone(),
12381276
count.clone(),
1277+
flag,
12391278
),
12401279
attribute_span: attr.span(),
12411280
ident_span: path.require_ident()?.span(),
@@ -1648,7 +1687,7 @@ enum SqlFunctionAttribute {
16481687
},
16491688
SqlName(Ident, LitStr),
16501689
Restriction(BackendRestriction),
1651-
Variadic(Ident, LitInt),
1690+
Variadic(Ident, LitInt, LitBool),
16521691
SkipReturnTypeHelper(Ident),
16531692
Other(Attribute),
16541693
}
@@ -1773,7 +1812,40 @@ impl SqlFunctionAttribute {
17731812
"backend_bounds" => {
17741813
BackendRestriction::parse_backend_bounds(input, name).map(Self::Restriction)?
17751814
}
1776-
"variadic" => Self::Variadic(name, input.parse()?),
1815+
"variadic" => {
1816+
if input.peek(LitInt) {
1817+
let count = input.parse::<LitInt>()?;
1818+
if !input.is_empty() {
1819+
return Err(syn::Error::new(
1820+
input.span(),
1821+
"unexpected token after positional `#[variadic(..)]`",
1822+
));
1823+
}
1824+
Self::Variadic(name, count, LitBool::new(false, Span::call_site()))
1825+
} else {
1826+
let key: Ident = input.parse()?;
1827+
if key != "last_arguments" {
1828+
return Err(syn::Error::new(key.span(), "expect `last_arguments`"));
1829+
}
1830+
let _eq: Token![=] = input.parse()?;
1831+
let count: LitInt = input.parse()?;
1832+
let skip_zero: LitBool = if input.peek(Token![,]) {
1833+
let _: Token![,] = input.parse()?;
1834+
let key: Ident = input.parse()?;
1835+
if key != "skip_zero_argument_variant" {
1836+
return Err(syn::Error::new(
1837+
key.span(),
1838+
"expect `skip_zero_argument_variant`",
1839+
));
1840+
}
1841+
let _eq: Token![=] = input.parse()?;
1842+
input.parse()?
1843+
} else {
1844+
LitBool::new(false, Span::call_site())
1845+
};
1846+
Self::Variadic(name, count, skip_zero)
1847+
}
1848+
}
17771849
_ => {
17781850
// empty the parse buffer otherwise syn will return an error
17791851
let _ = input.step(|cursor| {

0 commit comments

Comments
 (0)