diff --git a/psopacket/src/lib.rs b/psopacket/src/lib.rs index ef5a7a9..ff7c63c 100644 --- a/psopacket/src/lib.rs +++ b/psopacket/src/lib.rs @@ -647,19 +647,7 @@ pub fn pso_message(attr: TokenStream, item: TokenStream) -> TokenStream { q.into() } -#[proc_macro_derive(PSOPacketData)] -pub fn pso_packet_data(input: TokenStream) -> TokenStream { - let derive = parse_macro_input!(input as DeriveInput); - - let name = derive.ident; - - let fields = if let syn::Data::Struct(strct) = derive.data { - strct.fields - } - else { - return syn::Error::new(name.span(), "PSOPacketData only works on structs").to_compile_error().into(); - }; - +fn pso_packet_data_struct(name: syn::Ident, fields: syn::Fields) -> TokenStream { let attrs = match get_struct_fields(fields.iter()) { Ok(a) => a, Err(err) => return err @@ -695,3 +683,73 @@ pub fn pso_packet_data(input: TokenStream) -> TokenStream { q.into() } + +fn pso_packet_data_enum<'a>(name: syn::Ident, repr_type: syn::Ident, variants: impl Iterator + Clone) -> TokenStream { + let value_to_variant = variants + .clone() + .enumerate() + .map(|(i, variant)| { + quote! { + #i => #name::#variant, + } + }) + .collect::>(); + + let variant_to_value = variants + .enumerate() + .map(|(i, variant)| { + quote! { + #name::#variant => #repr_type::to_le_bytes(#i as #repr_type).to_vec(), + } + }) + .collect::>(); + let impl_pso_data_packet = quote! { + impl PSOPacketData for #name { + fn from_bytes(mut cur: &mut R) -> Result { + let mut buf = #repr_type::default().to_le_bytes(); + cur.read_exact(&mut buf).unwrap(); + let value = #repr_type::from_le_bytes(buf); + + Ok(match value as usize { + #(#value_to_variant)* + _ => return Err(PacketParseError::InvalidValue) + }) + } + + fn as_bytes(&self) -> Vec<#repr_type> { + match self { + #(#variant_to_value)* + } + } + } + }; + + impl_pso_data_packet.into() +} + +#[proc_macro_derive(PSOPacketData)] +pub fn pso_packet_data(input: TokenStream) -> TokenStream { + let derive = parse_macro_input!(input as DeriveInput); + + let name = derive.ident; + + if let syn::Data::Struct(strct) = derive.data { + pso_packet_data_struct(name, strct.fields) + } + else if let syn::Data::Enum(enm) = derive.data { + let repr_type = derive.attrs.iter().fold(None, |mut repr_type, attr| { + if attr.path().is_ident("repr") { + attr.parse_nested_meta(|meta| { + repr_type = Some(meta.path.get_ident().cloned().unwrap()); + Ok(()) + }).unwrap(); + } + repr_type + }); + + pso_packet_data_enum(name, repr_type.unwrap(), enm.variants.iter()) + } + else { + syn::Error::new(name.span(), "PSOPacketData only works on structs and enums").to_compile_error().into() + } +}